
I am trying to animate multiple lines at once in matplotlib. To do this I am following the tutorial from the matplotlib.animation docs:

https://matplotlib.org/stable/api/animation_api.html

The idea in this tutorial is to create an empty line with ln, = plt.plot([], []) and then update its data with ln.set_data to produce the animation. Whilst this works fine when the line data is a one-dimensional array (shape = (n,)) of n data points, I am having trouble when the line data is a two-dimensional array (shape = (n, k)) holding k lines to plot.
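For reference, the one-dimensional case that works looks roughly like this (a minimal sketch of the pattern from the animation docs; the sine data, frame count and interval are illustrative choices, not taken from the tutorial). Since set_data does not rescale the axes, the limits are fixed up front.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

x = np.linspace(0, 2 * np.pi, 100)
y = np.sin(x)  # illustrative 1D data, shape (100,)

fig, ax = plt.subplots()
# set_data() does not update the axis limits, so fix them up front
ax.set(xlim=(0, 2 * np.pi), ylim=(-1.1, 1.1))
ln, = ax.plot([], [])  # one empty Line2D; the comma unpacks the 1-element list


def update(frame):
    # a single line takes two 1D sequences of equal length
    ln.set_data(x[:frame], y[:frame])
    return ln,  # with blit=True, the changed artists must be returned


anim = FuncAnimation(fig, update, frames=len(x) + 1, interval=20, blit=True)

if __name__ == "__main__":
    plt.show()
```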

To be more precise, plt.plot accepts 2D arrays as inputs, with each column corresponding to a separate line. Here is a simple example with 3 lines plotted by a single plt.plot call:

import matplotlib.pyplot as plt
import numpy as np


x = np.linspace(0, 2 * np.pi, 100).reshape(-1, 1)
x = np.concatenate([x] * 3, axis=1)

# generate 3 curves
y = np.copy(x)
y[:, 0] = np.cos(y[:, 0])
y[:, 1] = np.sin(y[:, 1])
y[:, 2] = np.sin(y[:, 2]) + np.cos(y[:, 2])

fig, ax = plt.subplots()
plt.plot(x, y)
plt.show()

[Image: plt.plot using arrays]
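A single plt.plot call can draw several curves because it creates one Line2D artist per column and returns all of them in a list; there is no single multi-curve line object. A small check, using the same three curves built with np.column_stack and a 1D x (which plot pairs with every column of y):

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

x = np.linspace(0, 2 * np.pi, 100)  # shape (100,)
# same three curves as above, one per column: shape (100, 3)
y = np.column_stack([np.cos(x), np.sin(x), np.sin(x) + np.cos(x)])

fig, ax = plt.subplots()
lines = ax.plot(x, y)  # the 1D x is reused for every column of y

print(len(lines))  # 3
print(all(isinstance(line, Line2D) for line in lines))  # True
```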

However, if I try to set the data using .set_data(), as required for generating animations, I run into a problem:

import matplotlib.pyplot as plt
import numpy as np


x = np.linspace(0, 2 * np.pi, 100).reshape(-1, 1)
x = np.concatenate([x] * 3, axis=1)

# generate 3 curves
y = np.copy(x)
y[:, 0] = np.cos(y[:, 0])
y[:, 1] = np.sin(y[:, 1])
y[:, 2] = np.sin(y[:, 2]) + np.cos(y[:, 2])

fig, ax = plt.subplots()
p, = plt.plot([], [], color='b')
p.set_data(x, y)
plt.show()

[Image: the problem]
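A likely explanation, based on a reading of the matplotlib source (Line2D.recache; worth checking against your version): a Line2D flattens its input with ravel(), so 2D x and y become one long line of n * k points in row-major order, hopping between the curves at every x value. set_data() also leaves the data limits untouched, so an axes created around empty data may not frame the curves until relim() and autoscale_view() are called. This sketch rebuilds the example above and inspects the stored points:

```python
import numpy as np
import matplotlib.pyplot as plt

# the example's data: three identical x columns, shape (100, 3)
x = np.tile(np.linspace(0, 2 * np.pi, 100)[:, None], (1, 3))
y = np.column_stack([np.cos(x[:, 0]), np.sin(x[:, 1]), np.sin(x[:, 2]) + np.cos(x[:, 2])])

fig, ax = plt.subplots()
p, = ax.plot([], [], color='b')
p.set_data(x, y)

# the 2D inputs are flattened row by row into ONE line of 100 * 3 points
xy = p.get_xydata()
print(xy.shape)  # (300, 2)
# at x = 0 the line visits cos(0), sin(0) and sin(0) + cos(0) in turn
print(xy[:3])

# set_data() does not touch the data limits; rescale explicitly
ax.relim()
ax.autoscale_view()
```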

Is there a way to call set_data with two-dimensional arrays? Whilst I am aware that I could just create three plots p1, p2, p3 and call set_data on each of them in a loop, my real data consists of 1000-10,000 lines to plot, and this makes the animation too slow.

Many thanks for any help.

2 Answers


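An approach could be to create a list of Line2D objects and call set_data on each of them in a loop. Note that ax.plot() always returns a list of lines, even when only one line is plotted; passing it empty arrays of shape (0, k) therefore creates all k lines in a single call.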

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)

# generate 10 curves
y = np.sin(x.reshape(-1, 1) + np.random.uniform(0, 2 * np.pi, (1, 10)))

fig, ax = plt.subplots()
ax.set(xlim=(0, 2 * np.pi), ylim=(-1.5, 1.5))
# lines = [ax.plot([], [], lw=2)[0] for _ in range(y.shape[1])]
lines = ax.plot(np.empty((0, y.shape[1])), np.empty((0, y.shape[1])), lw=2)

def animate(i):
    for line_k, y_k in zip(lines, y.T):
        line_k.set_data(x[:i], y_k[:i])
    return lines

anim = FuncAnimation(fig, animate, frames=x.size, interval=200, repeat=False)
plt.show()
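A different technique, not part of this answer, that may help with the thousands of lines mentioned in the question: put all curves into a single matplotlib.collections.LineCollection and update it with one set_segments() call per frame, so matplotlib handles one artist instead of thousands of Line2D objects. This is a sketch assuming a shared 1D x and a y of shape (n_points, n_lines) as above; the sizes and random data are made up, and it has not been benchmarked against the loop.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.collections import LineCollection

n_points, n_lines = 100, 1000
x = np.linspace(0, 2 * np.pi, n_points)
# made-up data: n_lines phase-shifted sine curves, shape (n_points, n_lines)
rng = np.random.default_rng(0)
y = np.sin(x[:, None] + rng.uniform(0, 2 * np.pi, (1, n_lines)))

fig, ax = plt.subplots()
# one artist holds every curve; it starts with no segments
coll = LineCollection([], linewidths=0.5)
ax.add_collection(coll)
ax.set(xlim=(0, 2 * np.pi), ylim=(-1.5, 1.5))  # collections need explicit limits here


def segments_up_to(i):
    """Return an array of shape (n_lines, i, 2): the first i points of each curve."""
    xs = np.broadcast_to(x[:i, None], (i, n_lines))
    return np.stack([xs, y[:i]], axis=-1).transpose(1, 0, 2)


def animate(i):
    coll.set_segments(segments_up_to(i))  # a single update call for all curves
    return (coll,)


anim = FuncAnimation(fig, animate, frames=range(1, n_points + 1), interval=20, blit=True)

if __name__ == "__main__":
    plt.show()
```

Per-curve colours can be supplied through the colors argument of LineCollection if a single colour is not wanted.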

1 Comment

Great, that makes perfect sense. Thank you @JohanC

set_data() takes two one-dimensional arrays, the x and y data of a single line, so in this case three set_data() calls are needed, one per line.

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

x = np.linspace(0, 2 * np.pi, 100).reshape(-1, 1)
x = np.concatenate([x] * 3, axis=1)

# generate 3 curves
y = np.copy(x)
y[:, 0] = np.cos(y[:, 0])
y[:, 1] = np.sin(y[:, 1])
y[:, 2] = np.sin(y[:, 2]) + np.cos(y[:, 2])

fig, ax = plt.subplots()
ax = plt.axes(xlim=(0, 2 * np.pi), ylim=(-1.5, 1.5))
line1, = ax.plot([], [], lw=2)
line2, = ax.plot([], [], lw=2)
line3, = ax.plot([], [], lw=2)


def animate(i):
    # one set_data() call per line, each with a 1D column of x and y
    line1.set_data(x[:i, 0], y[:i, 0])
    line2.set_data(x[:i, 1], y[:i, 1])
    line3.set_data(x[:i, 2], y[:i, 2])
    return line1, line2, line3

anim = FuncAnimation(fig, animate, frames=100, interval=200, repeat=False)
plt.show()


7 Comments

Animation is easy once you understand how it works. If my answer has helped you with your question, please accept this as the correct answer.
Hi, thanks for the response. As I said in the question: "Whilst I am aware that I could just create three plots p1, p2, p3 and call set_data on each of them in a loop, my real data consists of 1000-10,000 lines to plot..." I tried to make a list of all plots and then update their data in a loop in the update function. However, this does not work, as the outputs of update need to be plots and not a list of plots.
Wouldn't it be sufficient to put all lines in a list? And then return that list? Something like lines = [ax.plot([], [], lw=2)[0] for _ in range(y.shape[1])], and use set_data in a loop.
@JohanC I don't see how set_data() would be called on lines when it is a list of Line2D objects. After all, lines[0].set_data() would still have to be written out for each of the 3 lines. Please advise.
I added an answer with my comment translated into code. It's just your approach extended to a loop. (Note that ax = plt.axes(...) creates a second subplot on top of the subplot created by plt.subplots(). You might want to remove one of them)
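A minimal sketch of the cleanup suggested in that last comment: keep the single Axes returned by plt.subplots() and set the limits on it with ax.set(...), instead of calling plt.axes(...), which adds a second Axes. The upper x limit of 2 * np.pi (rather than 6) is an adjustment so the curves are not cut off.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# configure the existing Axes instead of creating a new one with plt.axes(...)
ax.set(xlim=(0, 2 * np.pi), ylim=(-1.5, 1.5))
# create one empty line per curve on that same Axes
line1, = ax.plot([], [], lw=2)
line2, = ax.plot([], [], lw=2)
line3, = ax.plot([], [], lw=2)

print(len(fig.axes))  # 1
```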
