I'd like to plot a convergence process of the MLE algorithm with the plotly library.
Requirements:
- the points have to be colored colored in the colors of the clusters, and change accordingly each iteration
- the centroids of the clusters should be plotted on each iteration.
A plot of a single iteration may be produced by Code 1, with the desired output shown in Figure 1:
Code 1
import plotly.graph_objects as go
import numpy as np
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
fig = go.Figure()
for i in range(5):
fig.add_trace(
go.Scatter(
x=A[i:i+3][:, 0],
y=A[i:i+3][:, 1],
mode='markers',
name=f'cluster {i+1}',
marker_color=colors[i]
)
)
for c in clusters:
fig.add_trace(
go.Scatter(
x=[centroids[c-1][0]],
y=[centroids[c-1][1]],
name=f'centroid of cluster {c}',
mode='markers',
marker_color=colors[c-1],
marker_symbol='x'
)
)
fig.show()
Figure 1
I've seen this tutorial, but it seems that you can plot only a single trace in a graph_objects.Frame(), and Code 2 represents a simple example for producing an animated scatter plot of all the points, where each frame plots points from different cluster and the centroids:
Code 2
import plotly.graph_objects as go
import numpy as np
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
fig = go.Figure(
data=[go.Scatter(x=A[:3][:,0], y=A[:3][:,1], mode='markers', name='cluster 1', marker_color=colors[0])],
layout=go.Layout(
xaxis=dict(range=[-10, 10], autorange=False),
yaxis=dict(range=[-10, 10], autorange=False),
title="Start Title",
updatemenus=[dict(
type="buttons",
buttons=[dict(label="Play",
method="animate",
args=[None])])]
),
frames=[go.Frame(data=[go.Scatter(x=A[:3][:,0], y=A[:3][:,1], mode='markers', name='cluster 2', marker_color=colors[1])]),
go.Frame(data=[go.Scatter(x=A[3:5][:,0], y=A[3:5][:,1], mode='markers', name='cluster 3', marker_color=colors[2])]),
go.Frame(data=[go.Scatter(x=A[5:8][:,0], y=A[5:8][:,1], mode='markers', name='cluster 4', marker_color=colors[3])]),
go.Frame(data=[go.Scatter(x=A[8:][:,0], y=A[8:][:,1], mode='markers', name='cluster 5', marker_color=colors[4])]),
go.Frame(data=[go.Scatter(x=[centroids[0][0]], y=[centroids[0][1]], mode='markers', name='centroid of cluster 1', marker_color=colors[0], marker_symbol='x')]),
go.Frame(data=[go.Scatter(x=[centroids[1][0]], y=[centroids[1][1]], mode='markers', name='centroid of cluster 2', marker_color=colors[1], marker_symbol='x')]),
go.Frame(data=[go.Scatter(x=[centroids[2][0]], y=[centroids[2][1]], mode='markers', name='centroid of cluster 3', marker_color=colors[2], marker_symbol='x')]),
go.Frame(data=[go.Scatter(x=[centroids[3][0]], y=[centroids[3][1]], mode='markers', name='centroid of cluster 4', marker_color=colors[3], marker_symbol='x')]),
go.Frame(data=[go.Scatter(x=[centroids[4][0]], y=[centroids[4][1]], mode='markers', name='centroid of cluster 5', marker_color=colors[4], marker_symbol='x')])]
)
fig.show()
Why does Code 2 does not fit my needs:
- I need to plot all the frames produced by
Code 2in a single frame each iteration of the algorithm (i.e. each frame of the desired solution will look likeFigure 1)
What I have tried:
- I have tried producing a
graph_objects.Figure(), and adding it to agraph_objects.Frame()as shown inCode 3, but have gottenError 1.
Code 3:
import plotly.graph_objects as go
import numpy as np
A = np.random.randn(30).reshape((15, 2))
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = [1, 2, 3, 4, 5]
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
fig = go.Figure()
for i in range(5):
fig.add_trace(
go.Scatter(
x=A[i:i+3][:, 0],
y=A[i:i+3][:, 1],
mode='markers',
name=f'cluster {i+1}',
marker_color=colors[i]
)
)
for c in clusters:
fig.add_trace(
go.Scatter(
x=[centroids[c-1][0]],
y=[centroids[c-1][1]],
name=f'centroid of cluster {c}',
mode='markers',
marker_color=colors[c-1],
marker_symbol='x'
)
)
animated_fig = go.Figure(
data=[go.Scatter(x=A[:3][:, 0], y=A[:3][:, 1], mode='markers', name=f'cluster 0', marker_color=colors[0])],
layout=go.Layout(
xaxis=dict(range=[-10, 10], autorange=False),
yaxis=dict(range=[-10, 10], autorange=False),
title="Start Title",
updatemenus=[dict(
type="buttons",
buttons=[dict(label="Play",
method="animate",
args=[None])])]
),
frames=[go.Frame(data=[fig])]
)
animated_fig.show()
Error 1:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-681-11264f38e6f7> in <module>
43 args=[None])])]
44 ),
---> 45 frames=[go.Frame(data=[fig])]
46 )
47
~\Anaconda3\lib\site-packages\plotly\graph_objs\_frame.py in __init__(self, arg, baseframe, data, group, layout, name, traces, **kwargs)
241 _v = data if data is not None else _v
242 if _v is not None:
--> 243 self["data"] = _v
244 _v = arg.pop("group", None)
245 _v = group if group is not None else _v
~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in __setitem__(self, prop, value)
3973 # ### Handle compound array property ###
3974 elif isinstance(validator, (CompoundArrayValidator, BaseDataValidator)):
-> 3975 self._set_array_prop(prop, value)
3976
3977 # ### Handle simple property ###
~\Anaconda3\lib\site-packages\plotly\basedatatypes.py in _set_array_prop(self, prop, val)
4428 # ------------
4429 validator = self._get_validator(prop)
-> 4430 val = validator.validate_coerce(val, skip_invalid=self._skip_invalid)
4431
4432 # Save deep copies of current and new states
~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in validate_coerce(self, v, skip_invalid, _validate)
2671
2672 if invalid_els:
-> 2673 self.raise_invalid_elements(invalid_els)
2674
2675 v = to_scalar_or_list(res)
~\Anaconda3\lib\site-packages\_plotly_utils\basevalidators.py in raise_invalid_elements(self, invalid_els)
298 pname=self.parent_name,
299 invalid=invalid_els[:10],
--> 300 valid_clr_desc=self.description(),
301 )
302 )
ValueError:
Invalid element(s) received for the 'data' property of frame
Invalid elements include: [Figure({
'data': [{'marker': {'color': 'red'},
'mode': 'markers',
'name': 'cluster 1',
'type': 'scatter',
'x': array([-1.30634452, -1.73005459, 0.58746435]),
'y': array([ 0.15388112, 0.47452796, -1.86354483])},
{'marker': {'color': 'green'},
'mode': 'markers',
'name': 'cluster 2',
'type': 'scatter',
'x': array([-1.73005459, 0.58746435, -0.27492892]),
'y': array([ 0.47452796, -1.86354483, -0.20329897])},
{'marker': {'color': 'blue'},
'mode': 'markers',
'name': 'cluster 3',
'type': 'scatter',
'x': array([ 0.58746435, -0.27492892, 0.21002816]),
'y': array([-1.86354483, -0.20329897, 1.99487636])},
{'marker': {'color': 'yellow'},
'mode': 'markers',
'name': 'cluster 4',
'type': 'scatter',
'x': array([-0.27492892, 0.21002816, -0.0148647 ]),
'y': array([-0.20329897, 1.99487636, 0.73484184])},
{'marker': {'color': 'magenta'},
'mode': 'markers',
'name': 'cluster 5',
'type': 'scatter',
'x': array([ 0.21002816, -0.0148647 , 1.13589386]),
'y': array([1.99487636, 0.73484184, 2.08810809])},
{'marker': {'color': 'red', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 1',
'type': 'scatter',
'x': [9],
'y': [6]},
{'marker': {'color': 'green', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 2',
'type': 'scatter',
'x': [0],
'y': [5]},
{'marker': {'color': 'blue', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 3',
'type': 'scatter',
'x': [8],
'y': [6]},
{'marker': {'color': 'yellow', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 4',
'type': 'scatter',
'x': [7],
'y': [1]},
{'marker': {'color': 'magenta', 'symbol': 'x'},
'mode': 'markers',
'name': 'centroid of cluster 5',
'type': 'scatter',
'x': [6],
'y': [2]}],
'layout': {'template': '...'}
})]
The 'data' property is a tuple of trace instances
that may be specified as:
- A list or tuple of trace instances
(e.g. [Scatter(...), Bar(...)])
- A single trace instance
(e.g. Scatter(...), Bar(...), etc.)
- A list or tuple of dicts of string/value properties where:
- The 'type' property specifies the trace type
One of: ['area', 'bar', 'barpolar', 'box',
'candlestick', 'carpet', 'choropleth',
'choroplethmapbox', 'cone', 'contour',
'contourcarpet', 'densitymapbox', 'funnel',
'funnelarea', 'heatmap', 'heatmapgl',
'histogram', 'histogram2d',
'histogram2dcontour', 'image', 'indicator',
'isosurface', 'mesh3d', 'ohlc', 'parcats',
'parcoords', 'pie', 'pointcloud', 'sankey',
'scatter', 'scatter3d', 'scattercarpet',
'scattergeo', 'scattergl', 'scattermapbox',
'scatterpolar', 'scatterpolargl',
'scatterternary', 'splom', 'streamtube',
'sunburst', 'surface', 'table', 'treemap',
'violin', 'volume', 'waterfall']
- All remaining properties are passed to the constructor of
the specified trace type
(e.g. [{'type': 'scatter', ...}, {'type': 'bar, ...}])
- I've succeeded to get all the points present in each frame with the use of
plotly.expressmodule, as shown inCode 3, but the only thing that is missing there is for the centroids to be marked asxs.
Code 3:
import plotly.express as px
import numpy as np
import pandas as pd
A = np.random.randn(200).reshape((100, 2))
iteration = np.array([1, 2, 3, 4, 5]).repeat(20)
centroids = np.random.randint(10, size=10).reshape((5, 2))
clusters = np.random.randint(1, 6, size=100)
colors = ['red', 'green', 'blue', 'yellow', 'magenta']
df = pd.DataFrame(dict(x1=A[:, 0], x2=A[:, 1], type='point', cluster=pd.Series(clusters, dtype='str'), iteration=iteration))
centroid_df = pd.DataFrame(dict(x1=centroids[:, 0], x2=centroids[:, 1], type='centroid', cluster=[1, 2, 3, 4, 5], iteration=[1, 2, 3, 4, 5]))
df = df.append(centroid_df, ignore_index=True)
px.scatter(df, x="x1", y="x2", animation_frame="iteration", color="cluster", hover_name="cluster", range_x=[-10,10], range_y=[-10,10])
I'd appreciate any help for achieving the desired result. Thanks.

