I am currently trying to create an interactive web-app to display and work with multi-graphs, ideally with streamlit. I have a first working version based on matplotlib (inspired by the netgraph package), but I cannot make it work with streamlit.
I include a first working version below to illustrate the functionality. The key feature (which I am struggling to get working with streamlit) is the following: edges between nodes are defined through "handles" (an edge is just a path going through these handles). All handles can be dragged (modifying the shape of the edge), and the first and last handles are projected onto the boundary of their respective nodes.
Is there a framework compatible with streamlit which gives a similar amount of control and degree of interactivity to draw network graphs?
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.patches import PathPatch
def rectangle(x, y, width, height):
    """Return a closed rectangular ``Path`` centred on ``(x, y)``.

    Parameters
    ----------
    x, y : float
        Coordinates of the rectangle's centre.
    width, height : float
        Full extent of the rectangle along the x- and y-axis respectively.

    Returns
    -------
    matplotlib.path.Path
        Closed path running counter-clockwise from the lower-left corner.
    """
    half_w = width / 2
    half_h = height / 2
    # The vertical offsets must use the height; using the width here only
    # happens to work for squares.
    verts = [
        (x - half_w, y - half_h),
        (x + half_w, y - half_h),
        (x + half_w, y + half_h),
        (x - half_w, y + half_h),
        (x - half_w, y - half_h),  # vertex of CLOSEPOLY is ignored; repeat the start
    ]
    codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY]
    return Path(verts, codes)
def line(handles):
    """Return an open polyline ``Path`` through the positions of ``handles``.

    Each element of ``handles`` must expose an ``xy`` attribute whose first
    two entries are the x and y coordinates; the path visits them in order.
    """
    points = []
    for handle in handles:
        position = handle.xy
        points.append([position[0], position[1]])
    # One MOVETO to start the path, then a LINETO for every further point.
    commands = [Path.MOVETO] + [Path.LINETO] * (len(points) - 1)
    return Path(points, commands)
class PathPatchDataUnits(PathPatch):
    """PathPatch whose line width is specified in data units.

    The width is stored in ``_lw_data`` and converted to points every time
    ``_linewidth`` is read, so the stroke scales with the axes.

    Adapted from Netgraph, see https://stackoverflow.com/a/42972469/2912349
    """

    def __init__(self, *args, **kwargs):
        # Take ``linewidth`` out of the kwargs so the parent class does not
        # interpret the data-unit value itself.
        data_width = kwargs.pop("linewidth", 1)
        super().__init__(*args, **kwargs)
        # Assigned after the parent constructor, which writes to
        # ``_linewidth`` (and therefore to ``_lw_data``) on its own.
        self._lw_data = data_width

    def _get_lw(self):
        """Return the line width in points, or 1 while not attached to axes."""
        if self.axes is None:
            return 1
        points_per_pixel = 72. / self.axes.figure.dpi
        to_display = self.axes.transData.transform
        width_px = to_display((self._lw_data, self._lw_data))
        origin_px = to_display((0, 0))
        # Index 0 selects the horizontal scale of the data-to-display transform.
        return ((width_px - origin_px) * points_per_pixel)[0]

    def _set_lw(self, lw):
        """Store the line width given in data units."""
        self._lw_data = lw

    # Shadows the plain ``_linewidth`` attribute with the computed property.
    _linewidth = property(_get_lw, _set_lw)
class NodeArtist(PathPatchDataUnits):
    """Rectangular, pickable node whose attached handles follow it.

    Parameters
    ----------
    xy : sequence of two floats
        Centre of the node in data coordinates.
    side_lengths : sequence of two floats
        Width and height of the node in data coordinates.
    line_width : float
        Outline width in data units.
    name : str
        Label stored on the artist.
    **kwargs
        Forwarded to ``PathPatchDataUnits``.
    """

    def __init__(self, xy=(0, 0), side_lengths=(2, 2), line_width=0.1, name='', **kwargs):
        # Store a private float array: keeping the caller's list would make
        # ``xy += direction`` extend that list instead of translating the
        # node, and would alias the caller's data.
        self._xy = np.array(xy, float)
        self.side_lengths = list(side_lengths)
        self.name = name
        super().__init__(path=self._build_path(), facecolor='none', picker=True, **kwargs)
        self._set_lw(line_width)
        self.dependents = set()  # handles that are pinned to this node's boundary
        self.zorder = 1

    def _build_path(self):
        """Return the rectangle path for the current position and size."""
        return rectangle(self.xy[0], self.xy[1], self.side_lengths[0], self.side_lengths[1])

    @property
    def xy(self):
        """Centre of the node."""
        return self._xy

    @xy.setter
    def xy(self, vals):
        self._xy = np.array(vals, float)
        self.set_path(self._build_path())
        self._update_dependents()
        self.stale = True

    def _update_dependents(self):
        """Re-project every dependent handle onto the node boundary."""
        for handle in self.dependents:
            handle.move(0, 0)

    def move(self, dx, dy):
        """Translate the node and its dependent handles by ``(dx, dy)``."""
        direction = np.array([dx, dy], float)
        # Order matters: shift the handles first, because assigning ``xy``
        # below immediately projects all dependents onto the new boundary.
        for dep in self.dependents:
            dep.xy = dep.xy + direction
        self.xy = self.xy + direction

    def add_dependent(self, handle):
        """Attach ``handle`` to this node and snap it onto the boundary."""
        self.dependents.add(handle)
        handle.node = self  # tells the handle which node it belongs to
        self._update_dependents()

    def projection_onto_boundary(self, pt):
        """Return the boundary point of the node associated with ``pt``.

        ``pt`` is first clamped into the rectangle, then pushed onto the side
        it is closest to (the vertical side wins a tie).
        """
        half_w = self.side_lengths[0] / 2
        half_h = self.side_lengths[1] / 2
        x_sign = 1 if pt[0] >= self.xy[0] else -1
        y_sign = 1 if pt[1] >= self.xy[1] else -1
        # Absolute offsets from the centre, clamped to the rectangle.
        a = min(x_sign * (pt[0] - self.xy[0]), half_w)
        b = min(y_sign * (pt[1] - self.xy[1]), half_h)
        if (half_w - a) <= (half_h - b):
            a = half_w  # closer to (or already on) a vertical side
        else:
            b = half_h  # strictly closer to a horizontal side
        return np.array([x_sign * a + self.xy[0], y_sign * b + self.xy[1]])
class HandleArtist(PathPatchDataUnits):
    """Small draggable square marking one control point of an edge.

    Parameters
    ----------
    xy : sequence of two floats
        Initial position in data coordinates.
    handle_size : float
        Side length of the square in data coordinates.
    node : optional
        Node the handle is pinned to; when given, the handle is projected
        onto that node's boundary.
    **line_kw
        Forwarded to ``PathPatchDataUnits``.
    """

    def __init__(self, xy, handle_size=0.25, node=None, **line_kw):
        # Store a private float array: with a list, ``xy += direction`` would
        # extend the list instead of translating the handle.
        self._xy = np.array(xy, float)
        self.handle_size = handle_size
        super().__init__(path=self._build_path(), edgecolor='black', facecolor='red', picker=True, **line_kw)
        self._set_lw(0.02)
        self._edges = set()  # edges this handle is part of
        self._node = None  # node this handle is pinned to, if it is an endpoint
        if node is not None:
            self.node = node
        self.zorder = 3

    def _build_path(self):
        """Return the square path for the current position."""
        return rectangle(self._xy[0], self._xy[1], self.handle_size, self.handle_size)

    @property
    def xy(self):
        """Position of the handle."""
        return self._xy

    @xy.setter
    def xy(self, vals):
        self._xy = np.array(vals, float)
        self.set_path(self._build_path())  # rebuild the handle itself
        for e in self._edges:  # redraw every edge passing through it
            e.refresh_edge()
        self.stale = True

    @property
    def node(self):
        """Node the handle is pinned to, or ``None``."""
        return self._node

    @node.setter
    def node(self, n):
        previous = self._node
        # Unregister from the previous node so it stops repositioning us.
        if previous is not None and previous is not n:
            previous.dependents.discard(self)
        if n is None:
            self._node = None
            return
        self._node = n
        n.dependents.add(self)
        self.move(0, 0)  # project onto the boundary of the new node

    def move(self, dx, dy):
        """Shift the handle by ``(dx, dy)``; pinned handles stay on their node's boundary."""
        target = self.xy + np.array([dx, dy], float)
        if self._node is None:
            self.xy = target
        else:
            self.xy = self._node.projection_onto_boundary(target)
class EdgeArtist(PathPatchDataUnits):
    """Polyline edge drawn through an ordered list of handles."""

    def __init__(self, handles, **line_kw):
        self.handles = handles
        self._path = line(self.handles)
        super().__init__(path=self._path, facecolor="none", picker=True, **line_kw)
        self._set_lw(0.1)
        self.zorder = 2
        # Register with each handle so that moving it refreshes this edge.
        for member in self.handles:
            member._edges.add(self)

    def refresh_edge(self):
        """Rebuild the path from the handles' current positions."""
        rebuilt = line(self.handles)
        self.set_path(rebuilt)
        self.stale = True

    def move(self, dx, dy):
        """Do nothing: an edge cannot be moved directly."""
        return None
class Graph():
    """Create and hold the node, edge and handle artists of a multigraph.

    Parameters
    ----------
    nodes : dict
        Maps each node name to its ``(x, y)`` position.
    edges : iterable of dict
        Each entry has an ``'endpoints'`` pair of node names and optionally a
        ``'handles'`` sequence of ``(x, y)`` control points. With fewer than
        two handles, the edge runs between the two node centres.
    ax : matplotlib axes
        Axes all artists are added to.
    """

    def __init__(self, nodes, edges, ax):
        # Set first: the helper methods below add their artists to ``self.ax``
        # rather than to a module-level global.
        self.ax = ax
        self.node_artists = self._add_nodes(nodes)
        self.edge_artists = []
        self.handle_artists = []
        self._add_edges(edges)
        self.ax.figure.canvas.draw()

    def _add_nodes(self, nodes):
        """Create one ``NodeArtist`` per entry and return them keyed by name."""
        artists = {}
        for name, position in nodes.items():
            artists[name] = NodeArtist(position, name=name)
            self.ax.add_patch(artists[name])
        return artists

    def _add_edge_and_corresponding_handles(self, edge):
        """Create the edge artist and its handles for one edge description.

        Returns the ``EdgeArtist`` and the list of its ``HandleArtist``s.
        Raises ``KeyError`` if an endpoint is not a known node name.
        """
        endpoint1, endpoint2 = edge['endpoints']
        start_node = self.node_artists[endpoint1]
        end_node = self.node_artists[endpoint2]

        given = edge.get('handles')
        if given is None or len(given) < 2:
            # Too few handles: start from copies of the node centres so the
            # handles never share a position object with a node.
            handles_xy = [np.array(start_node.xy, float), np.array(end_node.xy, float)]
        else:
            handles_xy = given

        # First and last handles are pinned to their nodes; the rest are free.
        e_handles = [HandleArtist(handles_xy[0], node=start_node)]
        e_handles.extend(HandleArtist(point) for point in handles_xy[1:-1])
        e_handles.append(HandleArtist(handles_xy[-1], node=end_node))

        edge_artist = EdgeArtist(e_handles)
        self.ax.add_patch(edge_artist)
        for handle in e_handles:
            self.ax.add_patch(handle)
        return edge_artist, e_handles

    def _add_edges(self, edges):
        """Create artists for all ``edges`` and record them on the graph.

        Returns the lists of edge artists and handle artists created by this
        call.
        """
        edge_artists = []
        handle_artists = []
        for edge in edges:
            e, h = self._add_edge_and_corresponding_handles(edge)
            edge_artists.append(e)
            handle_artists.extend(h)
        self.edge_artists = self.edge_artists + edge_artists
        self.handle_artists = self.handle_artists + handle_artists
        return edge_artists, handle_artists
class DraggableGraph(Graph):
    """Graph whose nodes and handles can be dragged with the mouse."""

    def __init__(self, nodes, edges, ax):
        super().__init__(nodes, edges, ax)
        self.active = None  # artist currently being dragged, if any
        self.offset = [0, 0]  # last mouse position in data coordinates
        canvas = self.ax.figure.canvas
        # Keep the connection ids so the callbacks can be disconnected later.
        self._cids = [
            canvas.mpl_connect('pick_event', self.on_press),
            canvas.mpl_connect('motion_notify_event', self.on_move),
            canvas.mpl_connect('button_release_event', self.on_release),
        ]

    def on_press(self, event):
        """Start dragging the picked artist if it has a position."""
        art = event.artist
        if not hasattr(art, "_xy"):  # only nodes and handles carry ``_xy``
            return
        me = event.mouseevent
        if me.xdata is None or me.ydata is None:  # pick without data coordinates
            return
        self.active = art
        self.offset = [me.xdata, me.ydata]

    def on_move(self, event):
        """Move the active artist by the mouse displacement since the last event."""
        if self.active is None or event.inaxes != self.ax:
            return
        if event.xdata is None or event.ydata is None:  # mouse left the axes
            return
        self.active.move(event.xdata - self.offset[0], event.ydata - self.offset[1])
        self.offset = [event.xdata, event.ydata]
        self.ax.figure.canvas.draw_idle()

    def on_release(self, event):
        """Stop dragging."""
        self.active = None
# ----------------- demo -----------------
if __name__ == "__main__":
    fig, ax = plt.subplots()
    ax.set_aspect("equal")
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    nodes = {'A': [5, 1], 'B': [4, 7], 'C': [8, 5], 'D': [2, 2]}
    # Endpoints must be keys of ``nodes``; two parallel edges between the same
    # pair of nodes demonstrate the multigraph case.
    edges = [
        {
            'endpoints': ('A', 'B'),
            'handles': [(1, 2), (3, 3), (5, 4), (10, 7)],
        },
        {
            'endpoints': ('A', 'B'),
            'handles': [(1, 5), (5, 5), (10, 1)],
        },
    ]
    g = DraggableGraph(nodes, edges, ax)
    print(g.node_artists['A'].xy)
    g.node_artists['B'].move(2, -5)
    g.handle_artists[2].move(-5, -1)
    plt.show()  # open the window and start the interactive event loop
st-link-analysis (github.com/AlrasheedA/st-link-analysis) or streamlit-cytoscapejs (github.com/andfanilo/streamlit-cytoscapejs); however, the latter no longer seems to be maintained.