I have a question similar to one that has been answered before (matplotlib: make plots in functions and then add each to a single subplot figure). However, my plots are more complex: each call draws two stacked panels, a trace and a posterior histogram. I'm using this function for plotting (taken from https://towardsdatascience.com/an-introduction-to-bayesian-inference-in-pystan-c27078e58d53):
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def plot_trace(param, param_name='parameter', ax=None, **kwargs):
    """Plot the trace and posterior of a parameter."""
    # Summary statistics
    mean = np.mean(param)
    median = np.median(param)
    cred_min, cred_max = np.percentile(param, 2.5), np.percentile(param, 97.5)
    # Plotting
    #ax = ax or plt.gca()
    # Top panel: trace of the samples
    plt.subplot(2, 1, 1)
    plt.plot(param)
    plt.xlabel('samples')
    plt.ylabel(param_name)
    plt.axhline(mean, color='r', lw=2, linestyle='--')
    plt.axhline(median, color='c', lw=2, linestyle='--')
    plt.axhline(cred_min, linestyle=':', color='k', alpha=0.2)
    plt.axhline(cred_max, linestyle=':', color='k', alpha=0.2)
    plt.title('Trace and Posterior Distribution for {}'.format(param_name))
    # Bottom panel: histogram and KDE of the posterior
    plt.subplot(2, 1, 2)
    plt.hist(param, 30, density=True)
    sns.kdeplot(param, shade=True)
    plt.xlabel(param_name)
    plt.ylabel('density')
    plt.axvline(mean, color='r', lw=2, linestyle='--', label='mean')
    plt.axvline(median, color='c', lw=2, linestyle='--', label='median')
    plt.axvline(cred_min, linestyle=':', color='k', alpha=0.2, label='95% CI')
    plt.axvline(cred_max, linestyle=':', color='k', alpha=0.2)
    plt.gcf().tight_layout()
    plt.legend()
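One likely reason the `ax` argument is ignored is that every drawing call goes through pyplot's implicit "current axes", and `plt.subplot(2, 1, ...)` creates fresh panels in the current figure. Below is a minimal sketch of the same function rewritten to draw only through Axes methods. The `axes` parameter (a pair `(trace_ax, hist_ax)`) is a hypothetical replacement for `ax`, and the seaborn KDE is left out to keep the example dependency-light; `sns.kdeplot` accepts an `ax` keyword, so `sns.kdeplot(param, ax=hist_ax)` could be added back.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_trace(param, param_name="parameter", axes=None):
    """Plot the trace and posterior of a parameter onto two given Axes.

    `axes` is assumed to be a pair (trace_ax, hist_ax); if omitted, a new
    2x1 figure is created. Every call below is an Axes method, so nothing
    depends on pyplot's current axes.
    """
    param = np.asarray(param)
    if axes is None:
        _, axes = plt.subplots(nrows=2, ncols=1)
    trace_ax, hist_ax = axes

    # Summary statistics: mean, median and central 95% interval
    mean = np.mean(param)
    median = np.median(param)
    cred_min, cred_max = np.percentile(param, [2.5, 97.5])

    # Top panel: trace of the samples
    trace_ax.plot(param)
    trace_ax.set_xlabel("samples")
    trace_ax.set_ylabel(param_name)
    trace_ax.axhline(mean, color="r", lw=2, linestyle="--")
    trace_ax.axhline(median, color="c", lw=2, linestyle="--")
    trace_ax.axhline(cred_min, linestyle=":", color="k", alpha=0.2)
    trace_ax.axhline(cred_max, linestyle=":", color="k", alpha=0.2)
    trace_ax.set_title("Trace and Posterior Distribution for {}".format(param_name))

    # Bottom panel: normalized histogram of the samples
    hist_ax.hist(param, 30, density=True)
    hist_ax.set_xlabel(param_name)
    hist_ax.set_ylabel("density")
    hist_ax.axvline(mean, color="r", lw=2, linestyle="--", label="mean")
    hist_ax.axvline(median, color="c", lw=2, linestyle="--", label="median")
    hist_ax.axvline(cred_min, linestyle=":", color="k", alpha=0.2, label="95% CI")
    hist_ax.axvline(cred_max, linestyle=":", color="k", alpha=0.2)
    hist_ax.legend()
    return trace_ax, hist_ax


# Usage with synthetic draws standing in for real posterior samples
rng = np.random.default_rng(0)
samples = rng.normal(loc=1.0, scale=0.2, size=1000)
trace_ax, hist_ax = plot_trace(samples, "A mean")
```

Because the function only touches the Axes it is given, the caller decides where the two panels live, for example one column of a larger grid.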
and I would like to have those two subplots for different parameters. If I use this code, it simply overwrites the plot and ignores the ax parameter. Would you please help me how to make it work also for more than just 2 parameters?
params = [mu_a,mu_tau]
param_names = ['A mean','tau mean']
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16,16))
fig.subplots_adjust(hspace=0.5)
fig.suptitle('Convergence and distribution of parameters')
plot_trace(params[0], param_name=param_names[0], ax = ax1)
plot_trace(params[1], param_name=param_names[1], ax = ax2)
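Since each parameter needs two panels, a `1 x 2` grid has too few Axes for two parameters. A minimal sketch, assuming a layout of one column per parameter (trace on top, histogram below), is to request a `2 x N` grid and hand each column to the plotting function. `plot_traces` and the condensed `plot_trace` stand-in are hypothetical helpers, the figure-size heuristic is an assumption, and the third parameter `sigma` is invented purely to show N > 2.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_trace(param, param_name, trace_ax, hist_ax):
    """Condensed stand-in: trace on top, histogram below, mean marked on both."""
    mean = np.mean(param)
    trace_ax.plot(param)
    trace_ax.axhline(mean, color="r", lw=2, linestyle="--")
    trace_ax.set_ylabel(param_name)
    trace_ax.set_title("Trace and Posterior Distribution for {}".format(param_name))
    hist_ax.hist(param, 30, density=True)
    hist_ax.axvline(mean, color="r", lw=2, linestyle="--", label="mean")
    hist_ax.set_xlabel(param_name)
    hist_ax.legend()


def plot_traces(params, param_names, figsize=None):
    """Draw one (trace, histogram) column per parameter in a 2 x N grid."""
    n = len(params)
    # squeeze=False keeps `axes` 2-D even when n == 1
    fig, axes = plt.subplots(nrows=2, ncols=n, figsize=figsize or (5 * n, 8),
                             squeeze=False)
    fig.suptitle("Convergence and distribution of parameters")
    for col, (param, name) in enumerate(zip(params, param_names)):
        plot_trace(param, name, axes[0, col], axes[1, col])
    # Leave room at the top for the suptitle
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    return fig, axes


# Usage with synthetic draws standing in for real posterior samples
rng = np.random.default_rng(0)
params = [rng.normal(1.0, 0.2, 1000),
          rng.gamma(2.0, 1.0, 1000),
          rng.normal(0.0, 1.0, 1000)]
param_names = ["A mean", "tau mean", "sigma"]
fig, axes = plot_traces(params, param_names)
```

The grid size is derived from `len(params)`, so the same call should work for any number of parameters; a very long list might be better wrapped into several rows of column pairs.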
The expected result is to have these plots next to each other, each one looking like the two-panel figure the function produces for a single parameter. Thank you.
Use ax1.plot / ax2.plot, like in this example: matplotlib.org/3.1.1/api/_as_gen/…
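Calling methods such as `ax1.plot` directly is the cleanest route. If rewriting every `plt.*` call is inconvenient, a smaller change may be to keep the pyplot calls and make the target Axes current with `plt.sca(ax)` before drawing, which also makes seaborn functions that default to the current axes draw there. The sketch below assumes the caller passes a hypothetical `(trace_ax, hist_ax)` pair and is trimmed to the mean line only.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_trace_sca(param, param_name, trace_ax, hist_ax):
    """Keep pyplot-style calls but aim them at the given Axes.

    plt.sca(ax) makes `ax` the current Axes (and its figure the current
    figure), so the following plt.* calls draw there instead of creating
    new panels with plt.subplot(2, 1, ...).
    """
    mean = np.mean(param)

    plt.sca(trace_ax)  # top panel
    plt.plot(param)
    plt.ylabel(param_name)
    plt.axhline(mean, color="r", lw=2, linestyle="--")

    plt.sca(hist_ax)  # bottom panel
    plt.hist(param, 30, density=True)
    plt.xlabel(param_name)
    plt.axvline(mean, color="r", lw=2, linestyle="--", label="mean")
    plt.legend()


# Usage: a 2 x 2 grid, one column per parameter (synthetic samples)
rng = np.random.default_rng(0)
params = [rng.normal(1.0, 0.2, 1000), rng.gamma(2.0, 1.0, 1000)]
param_names = ["A mean", "tau mean"]
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
for col, (param, name) in enumerate(zip(params, param_names)):
    plot_trace_sca(param, name, axes[0, col], axes[1, col])
fig.tight_layout()
```

This keeps the original function body mostly intact, at the cost of relying on pyplot's global state, which is easier to get wrong than passing Axes explicitly.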