0

I am currently trying to create an interactive web-app to display and work with multi-graphs, ideally with streamlit. I have a first working version based on matplotlib (inspired by the netgraph package), but I cannot make it work with streamlit.

I include a first working version below to illustrate the functionality. The key part (which I am struggling to get working with streamlit) is the following: edges between nodes are defined through "handles" (an edge is just a path going through these handles). All handles can be dragged (modifying the shape of the edge), and the first and last handles of each edge are projected onto the boundaries of their respective nodes.

Is there a framework compatible with streamlit which gives a similar amount of control and degree of interactivity to draw network graphs?

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.patches import PathPatch



def rectangle(x, y, width, height):
    """Return a closed rectangular ``Path`` centred on ``(x, y)``.

    Parameters
    ----------
    x, y : float
        Centre of the rectangle.
    width, height : float
        Full extent along the x and y axis respectively.

    Returns
    -------
    Path
        Four corners (counter-clockwise from bottom-left) plus a closing
        vertex.
    """
    half_w = width / 2
    half_h = height / 2
    # The bottom edge must be offset by half the *height*; using half the
    # width there only gives the right shape for squares.
    verts = [
        (x - half_w, y - half_h),
        (x + half_w, y - half_h),
        (x + half_w, y + half_h),
        (x - half_w, y + half_h),
        (x - half_w, y - half_h),  # placeholder vertex for CLOSEPOLY
    ]
    codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY]
    return Path(verts, codes)

def line(handles):
    """Return an open polyline ``Path`` through the handles, in order.

    Each element of ``handles`` must expose an indexable ``xy`` whose first
    two entries are the x and y coordinate of that handle.
    """
    verts = []
    for handle in handles:
        px, py = handle.xy[0], handle.xy[1]
        verts.append([px, py])
    # First vertex starts the path, every following vertex extends it.
    codes = [Path.MOVETO]
    codes.extend(Path.LINETO for _ in verts[1:])
    return Path(verts, codes)


class PathPatchDataUnits(PathPatch):
    """PathPatch whose line width is specified in data units.

    The ``linewidth`` keyword is intercepted and stored as a data-space
    width; ``_linewidth`` is exposed as a property that converts that width
    to points on every read, using the current axes transform.

    Stolen from Netgraph resp. https://stackoverflow.com/a/42972469/2912349
    """

    def __init__(self, *args, **kwargs):
        data_width = kwargs.pop("linewidth", 1)
        super().__init__(*args, **kwargs)
        # Assigned after the base initialiser so that the requested width is
        # the value that ends up stored.
        self._lw_data = data_width

    def _get_lw(self):
        """Return the line width in points (1 while not attached to an axes)."""
        if self.axes is None:
            return 1
        points_per_pixel = 72. / self.axes.figure.dpi
        to_display = self.axes.transData.transform
        extent_px = to_display((self._lw_data, self._lw_data)) - to_display((0, 0))
        # Index 0 selects the x component of the transformed extent.
        return (extent_px * points_per_pixel)[0]

    def _set_lw(self, lw):
        """Store ``lw`` as the line width in data units."""
        self._lw_data = lw

    _linewidth = property(_get_lw, _set_lw)


class NodeArtist(PathPatchDataUnits):
    """Rectangular, pickable node that keeps its attached handles on its boundary.

    Parameters
    ----------
    xy : sequence of two floats
        Centre of the node. Copied into a float array.
    side_lengths : sequence of two floats
        Width and height of the rectangle.
    line_width : float
        Outline width in data units.
    name : str
        Label stored on the artist.
    **kwargs
        Forwarded to ``PathPatchDataUnits``.
    """

    def __init__(self, xy=(0, 0), side_lengths=(2, 2), line_width=0.1, name='', **kwargs):
        # Copy into an array so that ``xy`` always has the same type as after
        # the setter ran, and so the caller's sequence is never aliased.
        self._xy = np.array(xy, float)
        # Per-instance list: no default object is shared between nodes.
        self.side_lengths = list(side_lengths)
        self.name = name
        super().__init__(path=self._build_path(), facecolor='none', picker=True, **kwargs)
        self._set_lw(line_width)
        self.dependents = set()  # handles that are pinned to this node
        self.zorder = 1

    def _build_path(self):
        """Return the rectangle path for the current centre and side lengths."""
        return rectangle(self.xy[0], self.xy[1], self.side_lengths[0], self.side_lengths[1])

    @property
    def xy(self):
        """Centre of the node."""
        return self._xy

    @xy.setter
    def xy(self, vals):
        self._xy = np.array(vals, float)
        self.set_path(self._build_path())
        self._update_dependents()
        self.stale = True

    def _update_dependents(self):
        """Re-project every dependent handle onto the boundary."""
        for handle in self.dependents:
            handle.move(0, 0)

    def move(self, dx, dy):
        """Translate the node and its dependent handles by ``(dx, dy)``."""
        direction = np.array([dx, dy], float)
        for dep in self.dependents:
            dep.xy = dep.xy + direction
        # The order is important: assigning .xy re-projects all dependents
        # onto the node, so they must already be shifted at that point.
        self.xy = self.xy + direction

    def add_dependent(self, handle):
        """Attach ``handle`` to this node and project it onto the boundary."""
        self.dependents.add(handle)
        handle.node = self  # tells the handle which node it belongs to
        self._update_dependents()

    def projection_onto_boundary(self, pt):
        """Return the point on the node's boundary associated with ``pt``.

        ``pt`` is first clamped to the rectangle; then the coordinate that is
        closer to its side is snapped onto that side (x is snapped on ties).
        """
        cx, cy = self.xy[0], self.xy[1]
        half_w = self.side_lengths[0] / 2
        half_h = self.side_lengths[1] / 2
        x_sign = 1 if pt[0] >= cx else -1
        y_sign = 1 if pt[1] >= cy else -1
        # Work in the first quadrant relative to the centre.
        a = min(x_sign * (pt[0] - cx), half_w)
        b = min(y_sign * (pt[1] - cy), half_h)
        if (half_w - a) <= (half_h - b):
            a = half_w
        # Uses the possibly updated ``a`` from the line above.
        if (half_h - b) <= (half_w - a):
            b = half_h
        return np.array([x_sign * a + cx, y_sign * b + cy])


class HandleArtist(PathPatchDataUnits):
    """Small draggable square marking one control point of an edge.

    Parameters
    ----------
    xy : sequence of two floats
        Initial position. Copied into a float array.
    handle_size : float
        Side length of the square in data units.
    node : optional
        Node this handle is pinned to. When given, the handle is registered
        as a dependent of that node and projected onto its boundary.
    **line_kw
        Forwarded to ``PathPatchDataUnits``.
    """

    def __init__(self, xy, handle_size=0.25, node=None, **line_kw):
        # Always hold an array so position arithmetic behaves the same no
        # matter whether a tuple, list or array was passed in.
        self._xy = np.array(xy, float)
        self.handle_size = handle_size
        super().__init__(path=self._build_path(), edgecolor='black', facecolor='red', picker=True, **line_kw)
        self._set_lw(0.02)
        self._edges = set()  # edges this handle is part of
        self._node = None    # node this handle is pinned to, if it is an endpoint
        if node is not None:
            self.node = node
        self.zorder = 3

    def _build_path(self):
        """Return the square path centred on the current position."""
        return rectangle(self._xy[0], self._xy[1], self.handle_size, self.handle_size)

    @property
    def xy(self):
        """Position of the handle."""
        return self._xy

    @xy.setter
    def xy(self, vals):
        self._xy = np.array(vals, float)
        self.set_path(self._build_path())  # rebuilds the handle
        for e in self._edges:              # updates the edges the handle is part of
            e.refresh_edge()
        self.stale = True

    @property
    def node(self):
        """Node this handle is pinned to, or ``None``."""
        return self._node

    @node.setter
    def node(self, n):
        previous = self._node
        # Unregister from the old node; otherwise it would keep dragging a
        # handle that no longer belongs to it.
        if previous is not None and previous is not n:
            previous.dependents.discard(self)
        self._node = n
        if n is not None:
            n.dependents.add(self)  # adds the handle to the dependents of the node
            self.move(0, 0)         # projects onto the boundary of the node

    def move(self, dx, dy):
        """Shift the handle by ``(dx, dy)``; pinned handles stay on their node's boundary."""
        direction = np.array([dx, dy], float)
        if self._node is None:
            self.xy = self.xy + direction
        else:
            self.xy = self._node.projection_onto_boundary(self.xy + direction)

class EdgeArtist(PathPatchDataUnits):
    """Pickable polyline patch running through an ordered list of handles."""

    def __init__(self, handles, **line_kw):
        polyline = line(handles)
        self.handles = handles
        self._path = polyline
        super().__init__(path=polyline, facecolor="none", picker=True, **line_kw)
        self.zorder = 2
        self._set_lw(0.1)
        # Register with every handle so that moving one refreshes this edge.
        for member in handles:
            member._edges.add(self)

    def refresh_edge(self):
        """Rebuild the path from the handles' current positions."""
        updated = line(self.handles)
        self.set_path(updated)
        self.stale = True

    def move(self, dx, dy):
        """Do nothing: an edge cannot be moved."""

    
class Graph():
    """Builds node, edge and handle artists and adds them to an axes.

    Parameters
    ----------
    nodes : dict
        Maps node name to its ``[x, y]`` position.
    edges : iterable of dict
        Each dict has ``'endpoints'`` (a pair of node names) and optionally
        ``'handles'`` (a sequence of ``(x, y)`` positions). With fewer than
        two handles the edge runs straight between its two nodes.
    ax : matplotlib axes
        Axes all artists are added to.
    """

    def __init__(self, nodes, edges, ax):
        # Set first: the builders below add their patches to self.ax.
        self.ax = ax
        self.node_artists = self._add_nodes(nodes)
        self.edge_artists = []
        self.handle_artists = []
        self._add_edges(edges)
        self.ax.figure.canvas.draw()

    def _add_nodes(self, nodes):
        """Create one NodeArtist per entry and return them keyed by name."""
        artists = {}
        for name, xy in nodes.items():
            artist = NodeArtist(xy, name=name)
            self.ax.add_patch(artist)
            artists[name] = artist
        return artists

    def _add_edge_and_corresponding_handles(self, edge):
        """Create the EdgeArtist for ``edge`` plus its HandleArtists.

        Returns
        -------
        tuple
            ``(edge_artist, list_of_handle_artists)``.

        Raises
        ------
        KeyError
            If an endpoint name is not a known node.
        """
        start_name, end_name = edge['endpoints']
        start_node = self.node_artists[start_name]
        end_node = self.node_artists[end_name]

        handles_xy = edge.get('handles')
        if handles_xy is None or len(handles_xy) < 2:
            # Not enough handles given: start from the node centres; the
            # endpoint handles get projected onto the node boundaries.
            handles_xy = [start_node.xy, end_node.xy]

        # First and last handle are pinned to the endpoint nodes, the ones in
        # between are free.
        e_handles = [HandleArtist(handles_xy[0], node=start_node)]
        e_handles.extend(HandleArtist(xy) for xy in handles_xy[1:-1])
        e_handles.append(HandleArtist(handles_xy[-1], node=end_node))

        edge_artist = EdgeArtist(e_handles)
        self.ax.add_patch(edge_artist)
        for handle in e_handles:
            self.ax.add_patch(handle)

        return edge_artist, e_handles

    def _add_edges(self, edges):
        """Create artists for all ``edges`` and record them on the graph.

        Returns
        -------
        tuple
            ``(edge_artists, handle_artists)`` created by this call.
        """
        edge_artists = []
        handle_artists = []
        for edge in edges:
            e, h = self._add_edge_and_corresponding_handles(edge)
            self.edge_artists.append(e)
            self.handle_artists.extend(h)
            edge_artists.append(e)
            handle_artists.extend(h)
        return edge_artists, handle_artists
    


class DraggableGraph(Graph):
    """Graph whose nodes and handles can be dragged with the mouse.

    Parameters are the same as for ``Graph``.
    """

    def __init__(self, nodes, edges, ax):
        super().__init__(nodes, edges, ax)
        self.active = None    # artist currently being dragged, if any
        self.offset = [0, 0]  # last mouse position in data coordinates

        canvas = self.ax.figure.canvas
        canvas.mpl_connect('pick_event', self.on_press)
        canvas.mpl_connect('motion_notify_event', self.on_move)
        canvas.mpl_connect('button_release_event', self.on_release)

    def on_press(self, event):
        """Start dragging the picked artist if it is a node or handle."""
        art = event.artist
        if not hasattr(art, "_xy"):  # only nodes/handles carry a position
            return
        me = event.mouseevent
        if me.xdata is None or me.ydata is None:
            # No data coordinates for this click; on_move could not compute
            # a displacement from it.
            return
        self.active = art
        self.offset = [me.xdata, me.ydata]

    def on_move(self, event):
        """Move the active artist by the mouse displacement since the last event."""
        if self.active is None or event.inaxes != self.ax:
            return
        if event.xdata is None:  # mouse left axes
            return
        self.active.move(event.xdata - self.offset[0], event.ydata - self.offset[1])
        self.offset = [event.xdata, event.ydata]

        # Redraw on the axes this graph was built on, not on a global one.
        self.ax.figure.canvas.draw_idle()

    def on_release(self, event):
        """Stop dragging."""
        self.active = None

# ----------------- demo -----------------
if __name__ == "__main__":
    fig, ax = plt.subplots()
    ax.set_aspect("equal")
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)

    nodes = {'A': [5, 1], 'B': [4, 7], 'C': [8, 5], 'D': [2, 2]}

    # Endpoints must be keys of ``nodes``; two edges between the same pair of
    # nodes make this a multi-graph.
    edges = [
        {
            'endpoints': ('A', 'B'),
            'handles': [(1, 2), (3, 3), (5, 4), (10, 7)],
        },
        {
            'endpoints': ('A', 'B'),
            'handles': [(1, 5), (5, 5), (10, 1)],
        },
    ]

    g = DraggableGraph(nodes, edges, ax)

    print(g.node_artists['A'].xy)
    g.node_artists['B'].move(2, -5)

    g.handle_artists[2].move(-5, -1)

    # Open the window and start the event loop so the graph can be dragged.
    plt.show()

1

0

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.