1

I have a PyQt application which has a tab widget that can open any number of tabs. Each tab embeds a matplotlib canvas displaying graphs.

Lately, I have tried to use InteractiveGraph from the netgraph library, with little success despite the help of a similar Stack Overflow topic. Maybe it comes from the additional presence of tabs; I don't know.

What I observe is that I cannot click on the graph nodes, although the graphs are displayed properly.

Below is a quick example of my code (tabs have static graph values for testing; each added tab adds an embedded graph), and how I tried to implement the solution proposed in the similar topic. I was not sure whether mpl_connect was necessary, so I tried both with and without it, and it didn't change anything.

import sys
from PyQt5 import QtCore, QtWidgets
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.pyplot import Figure
import networkx as nx
import numpy as np
from netgraph import InteractiveGraph


class data_tab(QtWidgets.QWidget):
    """Tab page holding one matplotlib canvas with a single subplot.

    On construction the page adds itself to ``parent`` (a QTabWidget) under
    ``title``.  Graphs are drawn on ``self.axe`` and shown by redrawing
    ``self.canvas``.
    """

    def __init__(self, parent, title):

        QtWidgets.QWidget.__init__(self, parent)

        # Holder for the netgraph InteractiveGraph drawn on this page.
        # matplotlib's callback registry keeps bound-method callbacks only as
        # weak references, so an InteractiveGraph with no strong reference is
        # garbage-collected and its click/drag handlers silently disappear.
        # Whoever draws the graph should assign it here.
        self.graph_instance = None

        self.data_tab_glayout = QtWidgets.QGridLayout(self)
        # No explicit setParent(): addWidget() below reparents the canvas to
        # this page anyway.
        self.canvas = FigureCanvas(Figure(figsize=(5, 3)))

        # Give the canvas keyboard focus on click so key events reach matplotlib.
        self.canvas.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.canvas.setFocus()

        self.canvas_vlayout = QtWidgets.QVBoxLayout(self.canvas)
        self.data_tab_glayout.addWidget(self.canvas, 0, 0, 2, 1)

        self.canvas.mpl_connect('key_press_event', self.on_key_press)

        self.axe = self.canvas.figure.add_subplot(111)
        self.canvas.figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)
        # Add the page to the tab widget with its title in one call.
        parent.addTab(self, title)

    def on_key_press(self, event):
        """Echo the key pressed while the canvas has focus."""
        print("you press", event.key)


class MyApp(QtWidgets.QMainWindow):
    """Main window: a tab widget of graph pages plus "Add Tab" / "Refresh Tabs" buttons."""

    def __init__(self, parent=None):
        QtWidgets.QMainWindow.__init__(self, parent)

        self.showMaximized()

        self.centralwidget = QtWidgets.QWidget(self)
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.core_tab = QtWidgets.QTabWidget(self.centralwidget)
        self.verticalLayout.addWidget(self.core_tab)
        self.add_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.add_tab_btn)
        self.refresh_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.refresh_tab_btn)
        self.setCentralWidget(self.centralwidget)

        self.add_tab_btn.setText("Add Tab")
        self.refresh_tab_btn.setText("Refresh Tabs")

        self.core_tab.setEnabled(True)
        self.core_tab.setTabShape(QtWidgets.QTabWidget.Rounded)
        self.core_tab.setElideMode(QtCore.Qt.ElideNone)
        self.core_tab.setDocumentMode(False)
        self.core_tab.setTabsClosable(True)
        self.core_tab.setMovable(True)
        self.core_tab.setTabBarAutoHide(False)

        # Number of tabs opened so far; used for titles and to cycle graphs.
        self.tab_counter = 0

        # Sample graphs: (producer, [consumers]); one edge producer -> consumer each.
        self.random_tabs = [("a", ["b", "c"]),
                            ("d", ["e", "f", "g"]),
                            ("h", ["i", "j", "k", "l"]),
                            ("m", ["n"]),
                            ("o", ["p", "q"]),
                            ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])]

        self.add_tab_btn.clicked.connect(self.openRandomTab)
        self.refresh_tab_btn.clicked.connect(self.refreshAllTabs)
        # setTabsClosable(True) only shows close buttons; QTabWidget merely
        # emits tabCloseRequested, so without a handler they do nothing.
        self.core_tab.tabCloseRequested.connect(self.closeTab)

    def openRandomTab(self):
        """Open a new tab showing the next sample graph and make it current."""
        graph_index = self.tab_counter % len(self.random_tabs)
        tab = data_tab(self.core_tab, "test " + str(self.tab_counter))
        # Remember which sample graph this tab shows, so a refresh redraws the
        # same graph even after tabs have been moved or closed.
        tab.graph_index = graph_index
        self._drawDataGraph(graph_index, tab)
        self.tab_counter += 1

        self.core_tab.setCurrentIndex(self.core_tab.indexOf(tab))

    def closeTab(self, index):
        """Remove the tab at ``index`` and schedule its page widget for deletion."""
        widget = self.core_tab.widget(index)
        self.core_tab.removeTab(index)
        if widget is not None:
            widget.deleteLater()

    def _drawDataGraph(self, tabNb, dataWidget):
        """Draw sample graph ``self.random_tabs[tabNb]`` on ``dataWidget``.

        The InteractiveGraph is stored as ``dataWidget.graph_instance``.
        """
        dataWidget.axe.cla()

        # 1. build the graph: one edge from the producer to each consumer
        producer, consumers = self.random_tabs[tabNb]

        DG = nx.DiGraph()
        for i, cons in enumerate(consumers):
            DG.add_edge(producer, cons, label=f"edge-{i}")

        # Producer yellow, consumers blue.  Compare with ==, not `in`: on
        # strings `in` is a substring test.
        node_color = {node: "#DCE46F" if node == producer else "#6FA2E4"
                      for node in DG}
        pos = nx.shell_layout(DG)
        pos[producer] = pos[producer] + np.array([0.2, 0])
        labels = nx.get_edge_attributes(DG, 'label')

        # Keep a strong reference on the tab.  matplotlib holds event
        # callbacks (bound methods) only weakly, so an InteractiveGraph kept in
        # a local variable is garbage-collected when this method returns, and
        # its nodes can no longer be clicked or dragged.
        dataWidget.graph_instance = InteractiveGraph(
            DG, node_layout=pos, edge_layout='curved', origin=(-1, -1), scale=(2, 2),
            node_color=node_color, node_size=8.,
            node_labels=True, node_label_fontdict=dict(size=10),
            edge_labels=labels, edge_label_fontdict=dict(size=10), ax=dataWidget.axe,
        )

        dataWidget.canvas.draw()

    def refreshAllTabs(self):
        """Redraw the graph of every open tab."""
        for tab_index in range(self.core_tab.count()):
            data_tab_widget = self.core_tab.widget(tab_index)
            # Prefer the graph the tab was opened with; fall back to the
            # position-based choice for tabs that did not record one.
            graph_index = getattr(data_tab_widget, "graph_index",
                                  tab_index % len(self.random_tabs))
            self._drawDataGraph(graph_index, data_tab_widget)




# Blank the argument list before handing it to Qt (presumably so launcher
# arguments from an IDE/notebook are not parsed as Qt options).
sys.argv = ['']
# Qt allows only one QApplication per process: reuse the existing one when the
# script is re-run inside a live interpreter instead of failing.
app = QtWidgets.QApplication.instance() or QtWidgets.QApplication(sys.argv)
main_app = MyApp()
main_app.show()
app.exec_()

2 Answers 2

2

If you want to click on the figure, I think you need to call mpl_connect via the figure, not on the canvas:

# Keep a handle on the Figure; FigureCanvas(fig) attaches itself to fig, so
# self.figure.canvas and self.canvas refer to the same canvas object.
self.figure = Figure(facecolor='white')
self.canvas = FigureCanvas(self.figure)

then make the connection with the 'button_press_event'

# Call self.onclick for every mouse-button press on the canvas.
self.figure.canvas.mpl_connect('button_press_event', self.onclick)

If you want to get the node that was clicked, you can store DG and graph_instance on the data_tab instance:

    # Store the graph and its InteractiveGraph on the tab widget so the click
    # callback can reach them; this also keeps a strong reference to the
    # InteractiveGraph beyond the drawing function.
    dataWidget.DG = DG
    dataWidget.graph_instance =graph_instance

then use them in the callback function

def onclick(self, event):
    """Report a mouse click and print every node artist under the cursor.

    Intended as a matplotlib 'button_press_event' callback.  ``self`` is
    expected to carry ``graph_instance`` (an InteractiveGraph) once a graph has
    been drawn.  Returns the keys of the nodes that were hit; matplotlib
    ignores the return value.
    """
    import math

    x = event.xdata
    y = event.ydata
    # Clicks outside the axes have xdata/ydata == None; formatting None with
    # %f or subtracting from it would raise TypeError.
    if x is None or y is None:
        return []
    print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
          ('double' if event.dblclick else 'single', event.button,
           event.x, event.y, x, y))

    # No graph has been drawn yet, so there is nothing to hit-test.
    graph = getattr(self, 'graph_instance', None)
    if graph is None:
        return []

    hits = []
    for key, node in graph.node_artists.items():
        # A node is hit when the click lies inside its circle.
        if math.hypot(x - node.xy[0], y - node.xy[1]) < node.radius:
            print(node)
            hits.append(key)
    return hits

The complete program could look something like this:

import sys
from PyQt5 import QtCore, QtWidgets
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.pyplot import Figure
import networkx as nx
import numpy as np
from netgraph import InteractiveGraph


class data_tab(QtWidgets.QWidget):
    """Tab page with a matplotlib canvas that reports clicks on graph nodes.

    On construction the page adds itself to ``parent`` (a QTabWidget) under
    ``title``.  The owner draws an InteractiveGraph on ``self.axe`` and stores
    it in ``self.graph_instance``; ``onclick`` hit-tests its node artists.
    """

    def __init__(self, parent, title):

        QtWidgets.QWidget.__init__(self, parent)
        # Keep the tab widget under its own name: assigning ``self.parent``
        # would shadow QWidget.parent() on this instance.
        self.tab_widget = parent
        # Set by whoever draws the graph; None until then so onclick can tell.
        self.graph_instance = None
        self.data_tab_glayout = QtWidgets.QGridLayout(self)
        self.figure = Figure(facecolor='white')
        # addWidget() below parents the canvas to this page; no setParent needed.
        self.canvas = FigureCanvas(self.figure)

        self.canvas.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.canvas.setFocus()

        self.canvas_vlayout = QtWidgets.QVBoxLayout(self.canvas)
        self.data_tab_glayout.addWidget(self.canvas, 0, 0, 2, 1)

        self.figure.canvas.mpl_connect('button_press_event', self.onclick)
        self.axe = self.canvas.figure.add_subplot(111)
        self.canvas.figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)
        # Add the page to the tab widget with its title in one call.
        parent.addTab(self, title)

    def onclick(self, event):
        """Report a click and print every node artist under the cursor.

        Returns the keys of the nodes that were hit; matplotlib ignores the
        return value.
        """
        x = event.xdata
        y = event.ydata
        # Outside the axes xdata/ydata are None; %f or arithmetic on them
        # would raise TypeError.
        if x is None or y is None:
            return []
        print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, x, y))
        # Clicking before any graph has been drawn must not raise.
        if self.graph_instance is None:
            return []
        hits = []
        for key, node in self.graph_instance.node_artists.items():
            dist = ((x - node.xy[0]) ** 2 + (y - node.xy[1]) ** 2) ** 0.5
            if dist < node.radius:
                print(node)
                hits.append(key)
        return hits

    def on_key_press(self, event):
        """Echo the key pressed while the canvas has focus."""
        print("you press", event.key)


class MyApp(QtWidgets.QMainWindow):
    """Main window: a tab widget of graph pages plus "Add Tab" / "Refresh Tabs" buttons."""

    def __init__(self, parent=None):
        QtWidgets.QMainWindow.__init__(self, parent)

        self.showMaximized()

        self.centralwidget = QtWidgets.QWidget(self)
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.core_tab = QtWidgets.QTabWidget(self.centralwidget)
        self.verticalLayout.addWidget(self.core_tab)
        self.add_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.add_tab_btn)
        self.refresh_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.refresh_tab_btn)
        self.setCentralWidget(self.centralwidget)

        self.add_tab_btn.setText("Add Tab")
        self.refresh_tab_btn.setText("Refresh Tabs")

        self.core_tab.setEnabled(True)
        self.core_tab.setTabShape(QtWidgets.QTabWidget.Rounded)
        self.core_tab.setElideMode(QtCore.Qt.ElideNone)
        self.core_tab.setDocumentMode(False)
        self.core_tab.setTabsClosable(True)
        self.core_tab.setMovable(True)
        self.core_tab.setTabBarAutoHide(False)

        # Number of tabs opened so far; used for titles and to cycle graphs.
        self.tab_counter = 0

        # Sample graphs: (producer, [consumers]); one edge producer -> consumer each.
        self.random_tabs = [("a", ["b", "c"]),
                            ("d", ["e", "f", "g"]),
                            ("h", ["i", "j", "k", "l"]),
                            ("m", ["n"]),
                            ("o", ["p", "q"]),
                            ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])]

        self.add_tab_btn.clicked.connect(self.openRandomTab)
        self.refresh_tab_btn.clicked.connect(self.refreshAllTabs)
        # setTabsClosable(True) only shows close buttons; QTabWidget merely
        # emits tabCloseRequested, so without a handler they do nothing.
        self.core_tab.tabCloseRequested.connect(self.closeTab)

    def openRandomTab(self):
        """Open a new tab showing the next sample graph and make it current."""
        graph_index = self.tab_counter % len(self.random_tabs)
        tab = data_tab(self.core_tab, "test " + str(self.tab_counter))
        # Remember which sample graph this tab shows, so a refresh redraws the
        # same graph even after tabs have been moved or closed.
        tab.graph_index = graph_index
        self._drawDataGraph(graph_index, tab)
        self.tab_counter += 1

        self.core_tab.setCurrentIndex(self.core_tab.indexOf(tab))

    def closeTab(self, index):
        """Remove the tab at ``index`` and schedule its page widget for deletion."""
        widget = self.core_tab.widget(index)
        self.core_tab.removeTab(index)
        if widget is not None:
            widget.deleteLater()

    def _drawDataGraph(self, tabNb, dataWidget):
        """Draw sample graph ``self.random_tabs[tabNb]`` on ``dataWidget``.

        Stores the networkx graph as ``dataWidget.DG`` and the InteractiveGraph
        as ``dataWidget.graph_instance``; the tab's click callback reads the
        latter.
        """
        dataWidget.axe.cla()

        # 1. build the graph: one edge from the producer to each consumer
        producer, consumers = self.random_tabs[tabNb]

        DG = nx.DiGraph()
        for i, cons in enumerate(consumers):
            DG.add_edge(producer, cons, label=f"edge-{i}")

        # Producer yellow, consumers blue.  Compare with ==, not `in`: on
        # strings `in` is a substring test.
        node_color = {node: "#DCE46F" if node == producer else "#6FA2E4"
                      for node in DG}
        pos = nx.shell_layout(DG)
        pos[producer] = pos[producer] + np.array([0.2, 0])
        labels = nx.get_edge_attributes(DG, 'label')

        graph_instance = InteractiveGraph(
            DG, node_layout=pos, edge_layout='curved', origin=(-1, -1), scale=(2, 2),
            node_color=node_color, node_size=8.,
            node_labels=True, node_label_fontdict=dict(size=10),
            edge_labels=labels, edge_label_fontdict=dict(size=10), ax=dataWidget.axe,
            pickable=True,
        )
        dataWidget.DG = DG
        # The strong reference on the tab keeps the InteractiveGraph (and so
        # its weakly-held matplotlib event callbacks) alive after this returns.
        dataWidget.graph_instance = graph_instance
        dataWidget.canvas.draw()

    def refreshAllTabs(self):
        """Redraw the graph of every open tab."""
        for tab_index in range(self.core_tab.count()):
            data_tab_widget = self.core_tab.widget(tab_index)
            # Prefer the graph the tab was opened with; fall back to the
            # position-based choice for tabs that did not record one.
            graph_index = getattr(data_tab_widget, "graph_index",
                                  tab_index % len(self.random_tabs))
            self._drawDataGraph(graph_index, data_tab_widget)




# Blank the argument list before handing it to Qt (presumably so launcher
# arguments from an IDE/notebook are not parsed as Qt options).
sys.argv = ['']
# Qt allows only one QApplication per process: reuse the existing one when the
# script is re-run inside a live interpreter instead of failing.
app = QtWidgets.QApplication.instance() or QtWidgets.QApplication(sys.argv)
main_app = MyApp()
main_app.show()
app.exec_()
Sign up to request clarification or add additional context in comments.

4 Comments

OK, I tried your solution. It did not really work at first, but you put me on the right path! I got an exception because self.graph_instance did not exist, so I moved the InteractiveGraph call into the data_tab class, where the instantiation needs to use "self". mpl_connect is not necessary.
I put it in the _drawDataGraph function with dataWidget.graph_instance =graph_instance
By "clickable", did you just mean being able to move nodes on the graph?
Yep indeed, I did not need anything fancier
0

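OK, following ymmx's answer, I figured out that the InteractiveGraph instance must not be a local variable. It has to be stored somewhere that outlives the drawing function (here, as an attribute of the tab). Otherwise it is garbage-collected, and because matplotlib keeps only weak references to bound-method event callbacks, its click and drag handlers disappear with it.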

So here's the solution which seems to behave as wanted in my case:

import sys
from PyQt5 import QtCore, QtWidgets
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.pyplot import Figure
import networkx as nx
import numpy as np
from netgraph import InteractiveGraph


class data_tab(QtWidgets.QWidget):
    """Tab page holding one matplotlib canvas and its interactive netgraph.

    On construction the page adds itself to ``parent`` (a QTabWidget) under
    ``title``.  ``createInteractiveGraph`` draws a graph on ``self.axe`` and
    keeps the resulting InteractiveGraph in ``self.graph_instance``.
    """

    def __init__(self, parent, title):

        QtWidgets.QWidget.__init__(self, parent)

        # Current InteractiveGraph; None until createInteractiveGraph runs.
        self.graph_instance = None

        self.data_tab_glayout = QtWidgets.QGridLayout(self)
        self.figure = Figure(figsize=(5, 3))
        # addWidget() below parents the canvas to this page; no setParent needed.
        self.canvas = FigureCanvas(self.figure)

        self.canvas.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.canvas.setFocus()

        self.canvas_vlayout = QtWidgets.QVBoxLayout(self.canvas)
        self.data_tab_glayout.addWidget(self.canvas, 0, 0, 2, 1)

        self.axe = self.canvas.figure.add_subplot(111)
        self.canvas.figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)
        # Add the page to the tab widget with its title in one call.
        parent.addTab(self, title)

    def createInteractiveGraph(self, DG, pos, node_color, labels):
        """Draw ``DG`` as an InteractiveGraph on ``self.axe`` and redraw the canvas.

        ``pos`` is the node layout, ``node_color`` the per-node colours and
        ``labels`` the edge labels, all passed straight to InteractiveGraph.
        """
        # Must be an attribute, not a local: matplotlib holds the graph's
        # bound-method event callbacks only weakly, so a local instance would
        # be garbage-collected and the nodes would stop responding.
        self.graph_instance = InteractiveGraph(
            DG, node_layout=pos, edge_layout='curved', origin=(-1, -1), scale=(2, 2),
            node_color=node_color, node_size=8.,
            node_labels=True, node_label_fontdict=dict(size=10),
            edge_labels=labels, edge_label_fontdict=dict(size=10), ax=self.axe,
        )

        self.canvas.draw()


class MyApp(QtWidgets.QMainWindow):
    """Main window: a tab widget of graph pages plus "Add Tab" / "Refresh Tabs" buttons."""

    def __init__(self, parent=None):
        QtWidgets.QMainWindow.__init__(self, parent)

        self.showMaximized()

        self.centralwidget = QtWidgets.QWidget(self)
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.core_tab = QtWidgets.QTabWidget(self.centralwidget)
        self.verticalLayout.addWidget(self.core_tab)
        self.add_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.add_tab_btn)
        self.refresh_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.refresh_tab_btn)
        self.setCentralWidget(self.centralwidget)

        self.add_tab_btn.setText("Add Tab")
        self.refresh_tab_btn.setText("Refresh Tabs")

        self.core_tab.setEnabled(True)
        self.core_tab.setTabShape(QtWidgets.QTabWidget.Rounded)
        self.core_tab.setElideMode(QtCore.Qt.ElideNone)
        self.core_tab.setDocumentMode(False)
        self.core_tab.setTabsClosable(True)
        self.core_tab.setMovable(True)
        self.core_tab.setTabBarAutoHide(False)

        # Number of tabs opened so far; used for titles and to cycle graphs.
        self.tab_counter = 0

        # Sample graphs: (producer, [consumers]); one edge producer -> consumer each.
        self.random_tabs = [("a", ["b", "c"]),
                            ("d", ["e", "f", "g"]),
                            ("h", ["i", "j", "k", "l"]),
                            ("m", ["n"]),
                            ("o", ["p", "q"]),
                            ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])]

        self.add_tab_btn.clicked.connect(self.openRandomTab)
        self.refresh_tab_btn.clicked.connect(self.refreshAllTabs)
        # setTabsClosable(True) only shows close buttons; QTabWidget merely
        # emits tabCloseRequested, so without a handler they do nothing.
        self.core_tab.tabCloseRequested.connect(self.closeTab)

    def openRandomTab(self):
        """Open a new tab showing the next sample graph and make it current."""
        graph_index = self.tab_counter % len(self.random_tabs)
        tab = data_tab(self.core_tab, "test " + str(self.tab_counter))
        # Remember which sample graph this tab shows, so a refresh redraws the
        # same graph even after tabs have been moved or closed.
        tab.graph_index = graph_index
        self._drawDataGraph(graph_index, tab)
        self.tab_counter += 1

        self.core_tab.setCurrentIndex(self.core_tab.indexOf(tab))

    def closeTab(self, index):
        """Remove the tab at ``index`` and schedule its page widget for deletion."""
        widget = self.core_tab.widget(index)
        self.core_tab.removeTab(index)
        if widget is not None:
            widget.deleteLater()

    def _drawDataGraph(self, tabNb, dataWidget):
        """Build sample graph ``self.random_tabs[tabNb]`` and draw it on ``dataWidget``."""
        dataWidget.axe.cla()

        # 1. build the graph: one edge from the producer to each consumer
        producer, consumers = self.random_tabs[tabNb]

        DG = nx.DiGraph()
        for i, cons in enumerate(consumers):
            DG.add_edge(producer, cons, label=f"edge-{i}")

        # Producer yellow, consumers blue.  Compare with ==, not `in`: on
        # strings `in` is a substring test.
        node_color = {node: "#DCE46F" if node == producer else "#6FA2E4"
                      for node in DG}
        pos = nx.shell_layout(DG)
        pos[producer] = pos[producer] + np.array([0.2, 0])
        labels = nx.get_edge_attributes(DG, 'label')

        # The tab creates and keeps the InteractiveGraph itself.
        dataWidget.createInteractiveGraph(DG, pos, node_color, labels)

    def refreshAllTabs(self):
        """Redraw the graph of every open tab."""
        for tab_index in range(self.core_tab.count()):
            data_tab_widget = self.core_tab.widget(tab_index)
            # Prefer the graph the tab was opened with; fall back to the
            # position-based choice for tabs that did not record one.
            graph_index = getattr(data_tab_widget, "graph_index",
                                  tab_index % len(self.random_tabs))
            self._drawDataGraph(graph_index, data_tab_widget)




# Blank the argument list before handing it to Qt (presumably so launcher
# arguments from an IDE/notebook are not parsed as Qt options).
sys.argv = ['']
# Qt allows only one QApplication per process: reuse the existing one when the
# script is re-run inside a live interpreter instead of failing.
app = QtWidgets.QApplication.instance() or QtWidgets.QApplication(sys.argv)
main_app = MyApp()
main_app.show()
app.exec_()

Comments

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.