
Is there a way to make a grid of scatterplots that plots every column of a DataFrame against one fixed column (the Y variable)?

I can do this with a for loop in either matplotlib or seaborn (see the code below), but I can't make the plots show up in a grid.

I want them displayed in a grid so they are easier to compare.

This is what I CAN do:

for col in boston_df:
    plt.scatter(boston_df[col], boston_df["MEDV"], c="red", label=col)
    plt.ylabel("medv")
    plt.legend()
    plt.show()

or

for col in boston_df:
    sns.regplot(x=boston_df[col], y=boston_df["MEDV"])
    plt.show()

Now, if I create subplots and call ax.scatter() inside my loop, like this:

fig, ax = plt.subplots(3, 5,figsize=(16,6))
for col in boston_df:
    ax.scatter(boston_df[col], boston_df["MEDV"], c="red", label=col)
    plt.ylabel("medv")
    plt.legend()
    plt.show()

it raises AttributeError: 'numpy.ndarray' object has no attribute 'scatter'

Ideally, I'd like something as simple as this:

df.hist(figsize=(18,10), density=True, label=df.columns)
plt.show()
Answer:

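When plt.subplots() creates more than one subplot, it returns a NumPy array of Axes rather than a single Axes, which is why ax.scatter() fails. Pick one Axes per column and pass it through the ax argument of pandas DataFrame.plot or seaborn's regplot: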

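import matplotlib.pyplot as plt
import seaborn as sns

# ASSUMES boston_df IS ALREADY LOADED AND MEDV IS ITS FIRST COLUMN
fig, ax = plt.subplots(1, 5, figsize=(16,6))   # NCOLS MUST BE >= NUMBER OF FEATURE COLUMNS

for i,col in enumerate(boston_df.columns[1:]):
    #boston_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i])
    sns.regplot(x=boston_df[col], y=boston_df["MEDV"], ax=ax[i])

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)      # TO ACCOMMODATE TITLE

plt.show()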

To demonstrate with random data:

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

### DATA BUILD
np.random.seed(6012019)
random_df = pd.DataFrame(np.random.randn(50,6), 
                         columns = ['MEDV', 'COL1', 'COL2', 'COL3', 'COL4', 'COL5'])

### PLOT BUILD
fig, ax = plt.subplots(1, 5, figsize=(16,6))

for i,col in enumerate(random_df.columns[1:]):
    #random_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i])
    sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[i])

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)

plt.show()
plt.clf()
plt.close()

[Image: plot output]
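If a single row of plots is enough, seaborn's pairplot can draw every feature column against one target column in a single call via its x_vars and y_vars parameters, and kind='reg' adds the same regression fit that regplot draws. This is a minimal sketch that rebuilds the illustrative random_df from above so it runs on its own; feature_cols and grid are just local names chosen here.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(6012019)
random_df = pd.DataFrame(np.random.randn(50, 6),
                         columns=['MEDV', 'COL1', 'COL2', 'COL3', 'COL4', 'COL5'])

# every column except the target goes on the x axis; MEDV is the only y variable
feature_cols = [c for c in random_df.columns if c != 'MEDV']
grid = sns.pairplot(random_df, x_vars=feature_cols, y_vars=['MEDV'], kind='reg')

# grid.axes is a 2-D array of Axes: one row (for y_vars), one column per x variable
print(grid.axes.shape)
plt.show()
```

Because pairplot lays out one row per y variable, this does not wrap onto several rows; for a wrapped grid, the subplot indexing shown in the rest of this answer is still needed.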

For multiple rows and columns, ax is a 2-D NumPy array, so index each subplot by row and column: ax[row_idx, col_idx].

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

### DATA BUILD
np.random.seed(6012019)
random_df = pd.DataFrame(np.random.randn(50,14),
                         columns = ['MEDV', 'COL1', 'COL2', 'COL3', 'COL4',
                                    'COL5', 'COL6', 'COL7', 'COL8', 'COL9',
                                    'COL10', 'COL11', 'COL12', 'COL13'])

### PLOT BUILD
fig, ax = plt.subplots(2, 7, figsize=(16,6))

for i,col in enumerate(random_df.columns[1:]):
    #random_df.plot(kind='scatter', x=col, y='MEDV', ax=ax[i // 7, i % 7])
    if i <= 6:
        sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[0,i])      # FIRST ROW
    else:
        sns.regplot(x=random_df[col], y=random_df["MEDV"], ax=ax[1,i-7])    # SECOND ROW

ax[1,6].axis('off')                  # HIDES AXES ON LAST ROW AND COL

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)

plt.show()
plt.clf()
plt.close()

[Image: multiple rows subplots]
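To avoid hard-coding the row boundary in an if/else (for example when switching from a 2x7 to a 3x5 grid), one common pattern is to flatten the 2-D Axes array with ax.flat, pair it with the columns using zip, and hide whatever subplots are left over. A sketch under the same assumptions as the random example above (MEDV is the target); ncols = 5 is an arbitrary choice and the variable names are illustrative.

```python
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(6012019)
random_df = pd.DataFrame(np.random.randn(50, 14),
                         columns=['MEDV'] + [f'COL{i}' for i in range(1, 14)])

target = 'MEDV'
feature_cols = random_df.columns.drop(target)   # the 13 columns to plot against MEDV

ncols = 5                                       # plots per row; any value works
nrows = math.ceil(len(feature_cols) / ncols)    # enough rows to fit every column
# squeeze=False keeps ax 2-D even when nrows or ncols is 1
fig, ax = plt.subplots(nrows, ncols, figsize=(16, 3 * nrows), squeeze=False)

# ax.flat walks the Axes array row by row; zip stops at the shorter sequence
for axis, col in zip(ax.flat, feature_cols):
    sns.regplot(x=random_df[col], y=random_df[target], ax=axis)

# hide the subplots that did not receive a column
for axis in ax.flat[len(feature_cols):]:
    axis.axis('off')

fig.suptitle('My Scatter Plots')
fig.tight_layout()
fig.subplots_adjust(top=0.95)
plt.show()
```

With 13 feature columns and ncols = 5 this gives a 3x5 grid with the last two subplots hidden; changing ncols alone reshapes the grid, with no extra elif branches.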


Comments:

This code gives me scatterplots only for the first 5 columns of the data (skipping the first one, which is the target), but I have 13 columns. It also throws IndexError: index 5 is out of bounds for axis 0 with size 5, although it still displays the plots. If I switch to 2 rows with fig, ax = plt.subplots(2, 5, figsize=(16,6)), I get AttributeError: 'numpy.ndarray' object has no attribute 'scatter'.
You need enough subplots to accommodate all your columns; hence the IndexError. This solution assumes MEDV is the first column and that you do NOT want to plot it against itself. Simply change ncols in plt.subplots() from 5 to 13. For multiple rows, see the extended answer, where the indexing into ax must be adjusted.
It works with 2, 7 - I can't make it work with 3, 5... I'm sure it's just that, being a beginner, I don't fully understand what's going on with the code and can't thus make the adjustments myself. I'll keep trying, I probably have to add an elif if I want 3 rows? In any case the solution above with 2, 7 did work!
Really? There is nothing preventing this from working in either version. I never worked in Python 2! Maybe your libraries or environments are mixed. In future, post actual data sample for us to help. Happy coding!
