I am trying to create a figure containing 9 subplots (3 x 3). The X and Y axis data comes from a dataframe using groupby. Here is my code:

fig, axs = plt.subplots(3,3)
for index,cause in enumerate(cause_list):
    df[df['CAT']==cause].groupby('RYQ')['NO_CONSUMERS'].mean().axs[index].plot()
    axs[index].set_title(cause)

plt.show()

However, it does not produce the desired output; it raises an error. If I remove axs[index] from before plot() and pass it inside the call instead, as plot(ax=axs[index]), it runs and produces nine subplots but does not display any data in them (as shown in the figure).

Could anyone point out where I am making the mistake?

  • What is your cause_list? Commented May 27, 2020 at 4:52
  • cause_list filters the data frame. There are nine types of data in the data frame, and I want to show each type in its own subplot. Commented May 27, 2020 at 5:05
  • print(df[df['CAT']==cause].groupby('RYQ')['NO_CONSUMERS'].mean()) in the loop and check the output Commented May 27, 2020 at 5:08
  • @Pygirl As I said, the dataframe is quite long, so I am pasting one type of data from the output: Commented May 27, 2020 at 6:02
  • RYQ
    2016Q1    214.919355
    2016Q2    199.676471
    2016Q3    230.043956
    2016Q4    126.526316
    2017Q1    108.934426
    2017Q2    136.833333
    2017Q3    172.555556
    2017Q4    128.937500
    2018Q1    198.485714
    2018Q2    207.425000
    2018Q3    297.958333
    2018Q4    113.630769
    2019Q1    243.130435
    2019Q2    190.400000
    2019Q3    197.769231
    2019Q4    140.152542
    2020Q1    221.043478
    2020Q2     75.458333
    2020Q3    199.451613
    2020Q4    203.937500
    Freq: Q-MAR, Name: NO_CONSUMERS, dtype: float64
    Commented May 27, 2020 at 6:02

1 Answer

You need to flatten axs; otherwise it is a 2D array, and axs[index] returns a whole row of axes rather than a single one. You can also pass the axes to the plot function through the ax argument (see the pandas plot documentation). Using an example:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in data: nine categories, three RYQ groups
cause_list = np.arange(9)

df = pd.DataFrame({'CAT': np.random.choice(cause_list, 100),
                   'RYQ': np.random.choice(['A', 'B', 'C'], 100),
                   'NO_CONSUMERS': np.random.normal(0, 1, 100)})

fig, axs = plt.subplots(3, 3, figsize=(8, 6))
axs = axs.flatten()  # 3x3 array of axes -> 1D array of 9 axes
for index, cause in enumerate(cause_list):
    # Mean NO_CONSUMERS per RYQ for this category, drawn on its own subplot
    df[df['CAT'] == cause].groupby('RYQ')['NO_CONSUMERS'].mean().plot(ax=axs[index])
    axs[index].set_title(cause)

plt.tight_layout()
plt.show()
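
To see why the original loop failed, it may help to inspect what plt.subplots(3, 3) returns. The sketch below is only illustrative: the small `series` is made-up data, the variable names are hypothetical, and the Agg backend is chosen purely so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs headless
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes

fig, axs = plt.subplots(3, 3)

grid_shape = axs.shape            # (3, 3): a 2D array of Axes
row_type = type(axs[0]).__name__  # axs[0] is a whole row (an ndarray), not one Axes
flat_axs = axs.flatten()          # 1D array of the 9 Axes, row by row
first_is_axes = isinstance(flat_axs[0], Axes)

# Made-up data for illustration only
series = pd.Series([1.0, 2.0, 3.0], index=["2016Q1", "2016Q2", "2016Q3"])

# A Series has no attribute called `axs`, so chaining `.axs[index]`
# after `.mean()` raises AttributeError.
try:
    series.axs
    chained_access_fails = False
except AttributeError:
    chained_access_fails = True

# Passing a single Axes through `ax=` draws the line on that subplot.
series.plot(ax=flat_axs[0])
lines_drawn = len(flat_axs[0].lines)

print(grid_shape, row_type, first_is_axes, chained_access_fails, lines_drawn)
plt.close(fig)
```

Iterating over axs.flat, for example with zip(axs.flat, cause_list), is an alternative to flattening and indexing by position.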
