I'm trying (and failing) to create a nested boxplot from a three-dimensional NumPy array, for example A = np.random.uniform(size=(4, 100, 2)).

The kind of plot I mean is shown in the picture below, taken from the seaborn boxplot documentation.

[Image: nested boxplot]

1 Answer

You can use np.meshgrid() to generate three index arrays, one for each dimension of the 3D array. Flattening these arrays together with the data (via .ravel()) makes them suitable as input for seaborn. Optionally, the flattened arrays can be combined into a dataframe, which helps generate axis and legend labels automatically.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

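# Simulated data: a random walk reshaped to (groups, samples, hue levels)
A = np.random.normal(0.02, 1, size=(4, 100, 2)).reshape(-1).cumsum().reshape(4, -1, 2)
x_names = ['A', 'B', 'C', 'D']  # labels for the first dimension (x axis)
hue_names = ['x', 'y']          # labels for the third dimension (hue)
# indexing='ij' gives index arrays with the same shape as A, so after
# ravel() each position refers to the same element of A
dim1, dim2, dim3 = np.meshgrid(x_names, np.arange(A.shape[1]), hue_names, indexing='ij')
sns.boxplot(x=dim1.ravel(), y=A.ravel(), hue=dim3.ravel())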

plt.tight_layout()
plt.show()

[Image: boxplot from 3d array]
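As a small illustration (not part of the original answer), the property the meshgrid approach relies on can be checked on a toy array: with indexing='ij' the index arrays have the same shape as A, so their flattened versions line up element by element with A.ravel(). The array sizes and labels below are arbitrary.

```python
import numpy as np

# A small 3D array: 2 groups x 3 samples x 2 hue levels
A = np.arange(12).reshape(2, 3, 2)
x_names = ['A', 'B']
hue_names = ['x', 'y']

# indexing='ij' makes every output array shaped like A (2, 3, 2), so
# dim1[i, j, k] == x_names[i], dim2[i, j, k] == j, dim3[i, j, k] == hue_names[k]
dim1, dim2, dim3 = np.meshgrid(x_names, np.arange(A.shape[1]), hue_names, indexing='ij')

# ravel() flattens all four arrays in the same C order, so position n
# of each flattened array describes the same element of A
for label, sample, hue, value in zip(dim1.ravel(), dim2.ravel(), dim3.ravel(), A.ravel()):
    print(label, sample, hue, value)
```

The first printed lines are A 0 x 0, A 0 y 1 and A 1 x 2. With the default indexing='xy', the outputs for these inputs would have shape (3, 2, 2) instead of (2, 3, 2), so the flattened labels would no longer match A.ravel().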

To create a dataframe, the code could look like the following. Note that the numeric second dimension isn't needed explicitly for the boxplot.

df = pd.DataFrame({'dim1': dim1.ravel(),
                   'dim2': dim2.ravel(),
                   'dim3': dim3.ravel(),
                   'A': A.ravel()})

# some tests to be sure that the 3D array has been interpreted correctly
# (np.isclose because the two sums may add the values in a different order)
assert np.isclose(A[0, :, 0].sum(), df.loc[(df['dim1'] == 'A') & (df['dim3'] == 'x'), 'A'].sum())
assert np.isclose(A[2, :, 1].sum(), df.loc[(df['dim1'] == 'C') & (df['dim3'] == 'y'), 'A'].sum())

sns.boxplot(data=df, x='dim1', y='A', hue='dim3')
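An alternative sketch, not from the original answer, builds the same long-format table without meshgrid: pd.MultiIndex.from_product enumerates the label combinations with the last level varying fastest, which is the same C order that A.ravel() uses. The variable name long_df is just for illustration.

```python
import numpy as np
import pandas as pd

A = np.random.normal(0.02, 1, size=(4, 100, 2)).reshape(-1).cumsum().reshape(4, -1, 2)
x_names = ['A', 'B', 'C', 'D']
hue_names = ['x', 'y']

# One index entry per element of A, ordered like A.ravel()
index = pd.MultiIndex.from_product(
    [x_names, np.arange(A.shape[1]), hue_names],
    names=['dim1', 'dim2', 'dim3'],
)
# reset_index turns the three index levels into ordinary columns
long_df = pd.Series(A.ravel(), index=index, name='A').reset_index()

# Same kind of sanity check as above, with a floating-point tolerance
mask = (long_df['dim1'] == 'C') & (long_df['dim3'] == 'y')
assert np.isclose(A[2, :, 1].sum(), long_df.loc[mask, 'A'].sum())
print(long_df.head())
```

The resulting table has the same columns as df, so it should work with sns.boxplot(data=long_df, x='dim1', y='A', hue='dim3') in the same way.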

If the array or the names are very long, working with purely numeric index arrays uses less memory and is faster; the readable labels are then set afterwards on the ticks and the legend:

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

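# Simulated data: 10 groups x 1000 samples x 2 hue levels
A = np.random.normal(0.01, 1, size=(10, 1000, 2)).reshape(-1).cumsum().reshape(10, -1, 2)
# Integer index arrays instead of string labels
dim1, dim2, dim3 = np.meshgrid(np.arange(A.shape[0]), np.arange(A.shape[1]), np.arange(A.shape[2]), indexing='ij')
sns.set_style('whitegrid')
ax = sns.boxplot(x=dim1.ravel(), y=A.ravel(), hue=dim3.ravel(), palette='spring')
# Replace the integer tick labels and legend entries with readable names
ax.set_xticklabels(["alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta", "iota", "kappa"])
# (relabel the existing legend texts; the legend_.legendHandles attribute
# was deprecated in Matplotlib 3.7 and removed in 3.9)
legend = ax.get_legend()
for text, label in zip(legend.get_texts(), ['2019-2020', '2020-2021']):
    text.set_text(label)
legend.set_title('Year')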
sns.despine()
plt.tight_layout()
plt.show()

[Image: longer example for sns.boxplot from a 3d array]
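A possible middle ground, which is not part of the original answer, is pd.Categorical.from_codes: it keeps the integer index arrays as codes and stores each name only once, so the memory stays close to the numeric version while the labels travel with the data. This is a sketch assuming the same simulated data and names as the longer example; the column names name, Year and value are illustrative.

```python
import numpy as np
import pandas as pd

A = np.random.normal(0.01, 1, size=(10, 1000, 2)).reshape(-1).cumsum().reshape(10, -1, 2)
x_names = ["alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta", "iota", "kappa"]
hue_names = ['2019-2020', '2020-2021']

# Integer index arrays, exactly as in the fully numeric version
dim1, dim2, dim3 = np.meshgrid(np.arange(A.shape[0]), np.arange(A.shape[1]),
                               np.arange(A.shape[2]), indexing='ij')

# from_codes interprets each integer as a position in `categories`,
# so no per-row string is created
df = pd.DataFrame({
    'name': pd.Categorical.from_codes(dim1.ravel(), categories=x_names),
    'Year': pd.Categorical.from_codes(dim3.ravel(), categories=hue_names),
    'value': A.ravel(),
})
print(df.memory_usage(deep=True))
```

Passing this to sns.boxplot(data=df, x='name', y='value', hue='Year') should then give named ticks in category order and a legend titled Year without relabeling afterwards.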
