
(Edited based on suggestions in the comments) I already have a function that draws a plot:

import matplotlib.pyplot as plt

def plot_i(Y, ax=None):
    if ax is None:
        ax = plt.gca()
    fig = plt.figure()  # creates a new figure on every call
    ax.plot(Y)
    plt.close(fig)
    return fig
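A quick way to see why figures pile up with this pattern: every call to plt.figure() registers a new figure with pyplot, and plt.get_fignums() lists the figures pyplot is tracking. The sketch below assumes a non-interactive backend (Agg) and uses a hypothetical helper `plot_with_new_figure`, which is the function above without its plt.close() call.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, nothing is displayed
import matplotlib.pyplot as plt
import numpy as np

def plot_with_new_figure(Y, ax=None):
    """Hypothetical copy of plot_i above, minus the plt.close() call."""
    if ax is None:
        ax = plt.gca()
    fig = plt.figure()  # creates and registers a brand-new, empty figure
    ax.plot(Y)          # the data still goes to the axes that was passed in
    return fig

plt.close("all")
_, axs = plt.subplots(1, 2)  # one figure, registered with pyplot
for k in range(2):
    plot_with_new_figure(np.arange(5), ax=axs[k])

# 1 figure from subplots + 1 empty figure per call = 3 open figures
n_open = len(plt.get_fignums())
print(n_open)
```

The extra figures come from plt.figure() alone; the axes passed in already belongs to a figure, so the function has no need to create one.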

I want to use it to plot n arrays in a grid. For simplicity, assume the grid is (n // 2, 2) and that n is even. So far I have come up with this:

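import matplotlib.pyplot as plt

def multi_plot(Y_arr, function):
    n = len(Y_arr)
    fig, ax = plt.subplots(n // 2, 2)
    for i in range(n):
        # assign to one axis a call of the function (e.g. plot_i) that draws a plot
        pass  # placeholder: this is the part I don't know how to write
    plt.close(fig)
    return fig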

If I do something like this inside the loop (where k and j are the row and column of the i-th cell):

plot_i(Y_arr[:, i], ax=ax[k, j])

the result is correct, but I have to close the figures at the end every time; otherwise new figures keep accumulating in plt. Is there any way to avoid calling plt.close(fig) each time?
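One way to avoid plt.close() entirely is to build the grid on a matplotlib.figure.Figure created directly, which pyplot never registers. The sketch below is an assumption-laden alternative, not the asker's code: `multi_plot_unmanaged` is a hypothetical name, the data is random, and the Agg backend is assumed.

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure

def plot_i(Y, ax):
    # Draws only on the axes it receives; it never creates a figure.
    ax.plot(Y)

def multi_plot_unmanaged(Y_arr, function, n_cols=2):
    """Hypothetical variant of multi_plot built on a pyplot-independent Figure."""
    n = Y_arr.shape[1]
    n_rows = -(-n // n_cols)  # ceiling division
    fig = Figure()  # not registered with pyplot, so there is nothing to close
    ax = fig.subplots(n_rows, n_cols, squeeze=False)
    for i in range(n):
        function(Y_arr[:, i], ax=ax[i // n_cols, i % n_cols])
    return fig

plt.close("all")
data = np.random.default_rng(0).normal(size=(50, 4))
fig = multi_plot_unmanaged(data, plot_i)

buf = io.BytesIO()
fig.savefig(buf, format="png")  # saving works without going through pyplot
```

A Figure built this way is not shown by plt.show(), so this approach suits code that saves or embeds figures rather than displaying them interactively.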

  • Not sure I understand exactly what you are looking for, but the matplotlib documentation has a good example of how to distribute N subplots in an n x m grid. Commented Dec 27, 2021 at 20:17
  • @Mr.T I already have a way to distribute the plots; this one is easier and I will use it. However, in the linked example, what is done for each ax is ax.plot. I want to call my own plotting function for each ax. That way I could call multi_plot with different functions. Commented Dec 27, 2021 at 20:20
  • Pass plot_i an ax parameter and switch to the object-oriented interface within it. Commented Dec 27, 2021 at 20:20
  • @BigBen I tried, but I still don't understand how to do this; could you explain with an example? I added ax=None as a parameter to plot_i, which, if None, gets the current axis with plt.gca(). In multi_plot I then pass the axis of interest as the ax argument, but it still shows each individual plot and does not return a fig object that contains all of them. I will edit the question accordingly. Commented Dec 27, 2021 at 20:29
  • @BigBen I actually solved the first issue, thank you for the clarification! Commented Dec 27, 2021 at 20:42
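The object-oriented suggestion above amounts to plotting functions that receive an Axes and call only Axes methods (ax.plot, ax.hist, ax.set_title) instead of plt.* functions, so they never create or switch figures. A minimal sketch, assuming the Agg backend; `plot_line` and `plot_hist` are hypothetical functions sharing a (Y, ax) signature so they are interchangeable.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

def plot_line(Y, ax):
    # Object-oriented style: only Axes methods, no plt.* calls.
    ax.plot(Y)
    ax.set_title("line")

def plot_hist(Y, ax):
    # Hypothetical second plotting function with the same (Y, ax) signature.
    ax.hist(Y, bins=10)
    ax.set_title("histogram")

plt.close("all")
fig, (ax_left, ax_right) = plt.subplots(1, 2)
y = np.random.default_rng(1).normal(size=200)
plot_line(y, ax=ax_left)
plot_hist(y, ax=ax_right)
```

Both functions draw into the single figure created by plt.subplots, so only one figure is open afterwards.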

1 Answer


If I understand correctly, you are looking for something like this:

import numpy as np
import matplotlib.pyplot as plt

def plot_i(Y, ax=None):
    if ax is None:
        ax = plt.gca()
    # draw on the given axes only; no new figure is created here
    ax.plot(Y)

def multi_plot(Y_arr, function, n_cols=2):
    n = Y_arr.shape[1]
    n_rows = n // n_cols + (1 if n % n_cols else 0)  # enough rows for n plots
    # squeeze=False keeps ax 2-D even when there is only one row
    fig, ax = plt.subplots(n_rows, n_cols, squeeze=False)
    for i in range(n):
        # draw column i of Y_arr on the axes at row i // n_cols, column i % n_cols
        function(Y_arr[:, i], ax=ax[i // n_cols, i % n_cols])
    return fig

if __name__ == '__main__':
    x = np.linspace(0, 12.6, 100)
    # let's create some fake data: 14 damped sine waves, one per column
    data = np.exp(-np.linspace(0, .5, 14)[np.newaxis, :] * x[:, np.newaxis]) * np.sin(x[:, np.newaxis])
    fig = multi_plot(data, plot_i, 3)
    plt.show()
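Because multi_plot only needs a function with a (Y, ax) signature, any plotting routine written that way can be swapped in. The sketch below is a hedged variant of the answer's multi_plot, not the answerer's code: it passes a hypothetical `plot_hist`, uses squeeze=False so a single-row grid still indexes as 2-D, and hides the grid cells left over when n is not a multiple of n_cols. The Agg backend and random data are assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

def multi_plot(Y_arr, function, n_cols=2):
    n = Y_arr.shape[1]
    n_rows = -(-n // n_cols)  # ceiling division
    # squeeze=False keeps ax 2-D even when there is only one row
    fig, ax = plt.subplots(n_rows, n_cols, squeeze=False)
    for i in range(n):
        function(Y_arr[:, i], ax=ax[i // n_cols, i % n_cols])
    for extra in ax.flat[n:]:
        extra.set_visible(False)  # hide grid cells that received no data
    return fig

def plot_hist(Y, ax):
    # Hypothetical alternative plotting function with the same signature.
    ax.hist(Y, bins=5)

plt.close("all")
data = np.random.default_rng(2).normal(size=(100, 5))
fig = multi_plot(data, plot_hist, n_cols=3)  # 2 x 3 grid, last cell hidden
```

With 5 columns of data and n_cols=3, the grid has 6 cells and the sixth is hidden; calling it with 2 columns gives a single row that still indexes correctly.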

Be careful when using gca(): if no figure is active, it creates a new one.
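The gca() caveat can be checked directly: with no figures open, plt.gca() creates a figure (and an axes in it) as a side effect. A minimal sketch, assuming the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this sketch
import matplotlib.pyplot as plt

plt.close("all")
before = len(plt.get_fignums())  # no figures open yet
ax = plt.gca()                   # implicitly creates a figure and an axes
after = len(plt.get_fignums())
print(before, after)
```

This is why plot_i(Y) called without an ax argument opens a figure of its own, while passing ax explicitly does not.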
