
My question is similar to one that has been answered before (matplotlib: make plots in functions and then add each to a single subplot figure), but my plots are more advanced. I'm using this function for plotting (taken from https://towardsdatascience.com/an-introduction-to-bayesian-inference-in-pystan-c27078e58d53):

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def plot_trace(param, param_name='parameter', ax=None, **kwargs):
    """Plot the trace and posterior of a parameter."""

    # Summary statistics
    mean = np.mean(param)
    median = np.median(param)
    cred_min, cred_max = np.percentile(param, 2.5), np.percentile(param, 97.5)

    # Plotting
    #ax = ax or plt.gca()
    plt.subplot(2,1,1)
    plt.plot(param)
    plt.xlabel('samples')
    plt.ylabel(param_name)
    plt.axhline(mean, color='r', lw=2, linestyle='--')
    plt.axhline(median, color='c', lw=2, linestyle='--')
    plt.axhline(cred_min, linestyle=':', color='k', alpha=0.2)
    plt.axhline(cred_max, linestyle=':', color='k', alpha=0.2)
    plt.title('Trace and Posterior Distribution for {}'.format(param_name))

    plt.subplot(2,1,2)
    plt.hist(param, 30, density=True)
    sns.kdeplot(param, shade=True)
    plt.xlabel(param_name)
    plt.ylabel('density')
    plt.axvline(mean, color='r', lw=2, linestyle='--',label='mean')
    plt.axvline(median, color='c', lw=2, linestyle='--',label='median')
    plt.axvline(cred_min, linestyle=':', color='k', alpha=0.2, label='95% CI')
    plt.axvline(cred_max, linestyle=':', color='k', alpha=0.2)

    plt.gcf().tight_layout()
    plt.legend()

I would like to have those two subplots for each of several parameters. If I use the code below, each call simply overwrites the previous plot and the ax parameter is ignored. Could you please help me make it work, also for more than two parameters?

params = [mu_a,mu_tau]
param_names = ['A mean','tau mean']

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16,16))
fig.subplots_adjust(hspace=0.5)
fig.suptitle('Convergence and distribution of parameters')

plot_trace(params[0], param_name=param_names[0], ax = ax1)
plot_trace(params[1], param_name=param_names[1], ax = ax2)

The expected result is to have those plots next to each other.

[Image: how one subplot should look (what the function produces)]

Thank you.

  • You have to plot with ax1.plot / ax2.plot like in this example: matplotlib.org/3.1.1/api/_as_gen/… Commented Dec 12, 2019 at 13:27
  • Yes, your function needs to take in two axes and specifically plot to those. Commented Dec 12, 2019 at 13:54

1 Answer


As suggested by both @Sebastian-R and @ImportanceOfBeingErnest, the problem is that you are creating the subplots inside the function (with plt.subplot()) instead of using what is passed to the ax= keyword. In addition, because your function draws on two different subplots, you need to pass it two axes instances.
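To see why holding an axes object is not enough, here is a minimal sketch (assuming only matplotlib; the Agg backend line just lets it run without a display). It shows that pyplot functions always draw on the current axes, while Axes methods draw on the axes they are called on:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display; omit in a notebook
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

# pyplot functions (plt.plot, plt.xlabel, ...) draw on the *current* Axes,
# so merely holding a reference to ax1 does not redirect them.
plt.sca(ax1)            # make ax1 the current Axes (the plt.sca approach)
plt.plot([1, 2, 3])     # pyplot call: goes to plt.gca(), i.e. ax1
plt.xlabel("samples")

# Axes methods draw on the Axes they are called on (the object-oriented approach),
# independent of which Axes is current.
ax2.plot([3, 2, 1])
ax2.set_xlabel("samples")

print(len(ax1.lines), len(ax2.lines))  # 1 1
print(plt.gca() is ax1)                # True: ax2.plot did not change the current Axes
```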

There are two ways to correct the problem:

  • the first solution would be to re-write your function to use the object-oriented API of matplotlib (this is the solution that I recommend, although it requires more work)

code:

def plot_trace(param, axs, param_name='parameter', **kwargs):
    """Plot the trace and posterior of a parameter."""
    # Summary statistics
    (...)
    # Plotting
    ax1 = axs[0]
    ax1.plot(param)
    ax1.set_xlabel('samples')  # notice the difference in notation here
    (...)

    ax2 = axs[1]
    ax2.hist(...); sns.kdeplot(..., ax=ax2)
    (...)

  • the second solution keeps your current code, except that, because the plt.xxx() functions act on the current axes, you first need to make each axes you passed as an argument the current one with plt.sca():

code:

def plot_trace(param, axs, param_name='parameter', **kwargs):
    """Plot the trace and posterior of a parameter."""
    # Summary statistics
    (...)
    # Plotting
    plt.sca(axs[0])
    plt.plot(param)
    plt.xlabel('samples')
    (...)

    plt.sca(axs[1])
    plt.hist(...)
    (...)

Then create all the subplots you need up front (two per parameter, so plotting two parameters requires four subplots) and call your function as you originally intended:

params = [mu_a,mu_tau]
param_names = ['A mean','tau mean']

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(16,16))
fig.subplots_adjust(hspace=0.5)
fig.suptitle('Convergence and distribution of parameters')

plot_trace(params[0], axs=[ax1, ax3], param_name=param_names[0])
plot_trace(params[1], axs=[ax2, ax4], param_name=param_names[1])
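Generalizing this to any number of parameters, here is a sketch of the object-oriented version with the statistics and styling taken from the question's function, driven by a loop over the columns of a 2 × n grid. Assumptions: the random arrays and the third name "sigma" are hypothetical stand-ins for your mu_a, mu_tau, etc., and the seaborn KDE is left out to keep dependencies minimal (seaborn's kdeplot accepts ax=, so it can be drawn on ax_hist as well).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_trace(param, axs, param_name="parameter"):
    """Plot the trace (axs[0]) and posterior histogram (axs[1]) of one parameter."""
    ax_trace, ax_hist = axs

    # Summary statistics
    mean = np.mean(param)
    median = np.median(param)
    cred_min, cred_max = np.percentile(param, [2.5, 97.5])

    # Trace plot: every call targets ax_trace explicitly
    ax_trace.plot(param)
    ax_trace.set_xlabel("samples")
    ax_trace.set_ylabel(param_name)
    ax_trace.axhline(mean, color="r", lw=2, linestyle="--")
    ax_trace.axhline(median, color="c", lw=2, linestyle="--")
    ax_trace.axhline(cred_min, linestyle=":", color="k", alpha=0.2)
    ax_trace.axhline(cred_max, linestyle=":", color="k", alpha=0.2)
    ax_trace.set_title("Trace and Posterior Distribution for {}".format(param_name))

    # Posterior histogram (a KDE could be added with sns.kdeplot(param, ax=ax_hist))
    ax_hist.hist(param, 30, density=True)
    ax_hist.set_xlabel(param_name)
    ax_hist.set_ylabel("density")
    ax_hist.axvline(mean, color="r", lw=2, linestyle="--", label="mean")
    ax_hist.axvline(median, color="c", lw=2, linestyle="--", label="median")
    ax_hist.axvline(cred_min, linestyle=":", color="k", alpha=0.2, label="95% CI")
    ax_hist.axvline(cred_max, linestyle=":", color="k", alpha=0.2)
    ax_hist.legend()


# Hypothetical stand-in samples; replace with your own arrays (e.g. mu_a, mu_tau).
rng = np.random.default_rng(0)
params = [rng.normal(0, 1, 1000), rng.normal(5, 2, 1000), rng.gamma(2.0, 1.0, 1000)]
param_names = ["A mean", "tau mean", "sigma"]

n = len(params)
# 2 rows (trace, histogram) x n columns (one per parameter);
# squeeze=False keeps axs two-dimensional even when n == 1.
fig, axs = plt.subplots(nrows=2, ncols=n, figsize=(6 * n, 8), squeeze=False)
fig.suptitle("Convergence and distribution of parameters")

for i, (param, name) in enumerate(zip(params, param_names)):
    plot_trace(param, axs[:, i], param_name=name)  # column i: [trace Axes, histogram Axes]

fig.tight_layout()
```

Because axs[:, i] is column i of the grid, adding another parameter only means appending to params and param_names; the figure width grows with n so the columns stay readable.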

3 Comments

You may want to reread the question, and my comment below it.
I see, I missed the part where the function generates two subplots.
Thank you all very much for the comments, and @Diziet Asahi for the solutions; both work.
