
I would like to plot a DataFrame that has 41 columns, so there are 41 charts to plot. I wrote a script, but it loads very slowly. Is there a way to optimize this script? Is it possible to use a loop to simplify the lists passed to the zip function?

import matplotlib.pyplot as plt
import pandas as pd

fig,((axs),(axs2),(axs3),(axs4),(axs5),(axs6),(axs7),(axs8),(axs9)) = plt.subplots(9,5,figsize=(15,6))

for ax, y in zip(axs,['XPEV','M','MLCO','VIPS','HD']):
    ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
    ax.ticklabel_format(style='plain', axis='y')
    ax.set_title(y)
    
    for ax, y in zip(axs2,['LVS','PTON','SBUX','BLMN','NCLH']):
        ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
        ax.ticklabel_format(style='plain', axis='y')
        ax.set_title(y)
        
        for ax, y in zip(axs3,['NIO','NKE','NKLA','NLS','QS']):
            ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
            ax.ticklabel_format(style='plain', axis='y')
            ax.set_title(y)
            
            for ax, y in zip(axs4,['AYRO','RMO','TSLA','XL','ASO']):
                ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
                ax.ticklabel_format(style='plain', axis='y')
                ax.set_title(y)
                
                for ax, y in zip(axs5,['TOL','VSTO','BABA','FTCH','RIDE']):
                    ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
                    ax.ticklabel_format(style='plain', axis='y')
                    ax.set_title(y)

                    for ax, y in zip(axs6,['EBAY','DS','DKNG','DHI','UAA']):
                        ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
                        ax.ticklabel_format(style='plain', axis='y')
                        ax.set_title(y)
                        
                        for ax, y in zip(axs7,['VFC','TPX','ARVL','GM','GOEV']):
                            ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
                            ax.ticklabel_format(style='plain', axis='y')
                            ax.set_title(y)
                            
                            for ax, y in zip(axs8,['PLBY','CCL','GME','CVNA','LOTZ']):
                                ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
                                ax.ticklabel_format(style='plain', axis='y')
                                ax.set_title(y)
                                
                                for ax, y in zip(axs9,['F']):
                                    ax.plot(tickers_data.index.strftime("%d"),tickers_data['Volume',y])
                                    ax.ticklabel_format(style='plain', axis='y')
                                    ax.set_title(y)
                                

plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
plt.show()
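Part of the slowness likely comes from the indentation: each `for` loop sits inside the previous one, so every inner loop re-runs once per iteration of all the loops around it, and each re-run draws another line on top of the ones already on that axis. The sketch below only counts loop-body executions for the question's row lengths (eight rows of 5 tickers, then one row of 1). It does not use matplotlib, and the helper names are hypothetical.

```python
# Count how many times the loop bodies run when nine loops are nested
# versus run one after another (row lengths taken from the question).
row_lengths = [5, 5, 5, 5, 5, 5, 5, 5, 1]


def nested_body_runs(lengths):
    """Total body executions when each loop is nested inside the previous one."""
    total = 0
    runs = 1  # how many times the current loop's body runs in total
    for n in lengths:
        runs *= n  # the body runs n times for every entry into this loop
        total += runs
    return total


def sequential_body_runs(lengths):
    """Total body executions when the loops run one after another."""
    return sum(lengths)


nested = nested_body_runs(row_lengths)
sequential = sequential_body_runs(row_lengths)
print(nested)      # 878905
print(sequential)  # 41
```

Roughly 879 thousand `plot()` calls instead of 41 is consistent with the slow loading; un-indenting the loops so they run one after another, or using a single loop as in the answer, brings it down to one call per ticker.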

1 Answer

  • No sample data was provided, so I've simulated it
  • Start by flattening the 2D array of axes returned by plt.subplots() into 1D
  • Iterate over the tickers in level 1 of the column MultiIndex
  • With those steps in place, a simple plot() call per ticker is enough
  • tight_layout() over-compresses the plots for me, so I have commented it out
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# No sample data was given, so simulate 7 days of volume for the 41 tickers
tickers = (['XPEV', 'M', 'MLCO', 'VIPS', 'HD'] + ['LVS', 'PTON', 'SBUX', 'BLMN', 'NCLH']
           + ['NIO', 'NKE', 'NKLA', 'NLS', 'QS'] + ['AYRO', 'RMO', 'TSLA', 'XL', 'ASO']
           + ['TOL', 'VSTO', 'BABA', 'FTCH', 'RIDE'] + ['EBAY', 'DS', 'DKNG', 'DHI', 'UAA']
           + ['VFC', 'TPX', 'ARVL', 'GM', 'GOEV'] + ['PLBY', 'CCL', 'GME', 'CVNA', 'LOTZ']
           + ['F'])
tickers_data = pd.DataFrame({("Volume", t): np.random.randint(20, 200, 7) for t in tickers},
                            index=pd.date_range("1-Jan-2021", periods=7))

fig, ax = plt.subplots(9, 5, figsize=(15, 6))

# plt.subplots(9, 5) returns a 9x5 array of axes; flatten it to 1D so that
# ax[i] lines up with the i-th ticker column
ax = np.array(ax).flatten()
for i, y in enumerate(tickers_data.columns.get_level_values(1)):
    ax[i].plot(tickers_data.index.strftime("%d"), tickers_data['Volume', y])
    ax[i].ticklabel_format(style='plain', axis='y')
    ax[i].set_title(y)

# plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
plt.show()
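On the second part of the question (simplifying the lists passed to zip), one possible variant of the loop above, sketched under a few assumptions, zips the flattened axes directly with the ticker columns, so neither an index counter nor hand-written sublists are needed. It also hides the 4 panels left empty in the 9 x 5 grid and uses a taller figsize=(15, 20); that size is my own guess at why tight_layout() over-compressed a 9-row grid at (15, 6), not something from the answer. The data is simulated the same way as above.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Simulated data, built like the answer's: MultiIndex columns ("Volume", ticker)
tickers = (['XPEV', 'M', 'MLCO', 'VIPS', 'HD'] + ['LVS', 'PTON', 'SBUX', 'BLMN', 'NCLH']
           + ['NIO', 'NKE', 'NKLA', 'NLS', 'QS'] + ['AYRO', 'RMO', 'TSLA', 'XL', 'ASO']
           + ['TOL', 'VSTO', 'BABA', 'FTCH', 'RIDE'] + ['EBAY', 'DS', 'DKNG', 'DHI', 'UAA']
           + ['VFC', 'TPX', 'ARVL', 'GM', 'GOEV'] + ['PLBY', 'CCL', 'GME', 'CVNA', 'LOTZ']
           + ['F'])
tickers_data = pd.DataFrame(
    {("Volume", t): np.random.randint(20, 200, 7) for t in tickers},
    index=pd.date_range("1-Jan-2021", periods=7),
)

volume = tickers_data["Volume"]           # top-level selection: one column per ticker
days = tickers_data.index.strftime("%d")  # day-of-month labels for the x axis

# figsize=(15, 20) is an assumption: a taller figure so 9 rows are not squashed
fig, axes = plt.subplots(9, 5, figsize=(15, 20))
flat_axes = axes.ravel()                  # 9x5 grid -> 45 axes, row by row

# zip pairs each axis with a ticker and stops after the 41 tickers
for ax, ticker in zip(flat_axes, volume.columns):
    ax.plot(days, volume[ticker])
    ax.ticklabel_format(style='plain', axis='y')
    ax.set_title(ticker)

# Hide the unused panels at the end of the grid (45 - 41 = 4)
for ax in flat_axes[len(volume.columns):]:
    ax.set_axis_off()

fig.tight_layout()
plt.show()
```

zip stops at the shorter of its inputs, which is why 45 axes and 41 tickers can be paired without an IndexError or a separate list per row.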

2 Comments

Thank you so much for your response. I ran the script and it works. I have a question about this line: ax = np.array(ax).flatten(). What does it do? I read the documentation, which says the function changes the array to one dimension, but I don't understand what that does here.
np.array([[1,2],[3,4]]).flatten() returns array([1, 2, 3, 4]); it has transformed a 2D array into a 1D array. Your solution was complicated by the fact that plt.subplots(9,5) returns a 2D array, which you were then trying to navigate with many nested for loops. Just give the axes array the same structure as the columns, indexed the same way, and it becomes very simple :-)
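To make the reply concrete, a short sketch (using only NumPy and matplotlib, as in the answer) showing the flatten() example, the shape of the array that plt.subplots(9, 5) returns, and how a 2D grid position maps to a flat index.

```python
import numpy as np
import matplotlib.pyplot as plt

# The example from the reply: a 2D array becomes 1D, read row by row
a = np.array([[1, 2], [3, 4]])
print(a.flatten())        # [1 2 3 4]

# plt.subplots(9, 5) returns a 9x5 NumPy array of Axes objects
fig, axes = plt.subplots(9, 5)
print(axes.shape)         # (9, 5)

# After flattening there is one Axes per position, so a single index i works
flat = axes.flatten()
print(flat.shape)         # (45,)

# flatten() uses row-major order by default: panel (row, col) lands at row * 5 + col
row, col = 2, 3
print(flat[row * 5 + col] is axes[row, col])  # True

plt.close(fig)
```

Because axes is already a NumPy array here, axes.flatten() alone would do; wrapping it in np.array() first, as the answer does, is harmless.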
