
I am trying to reproduce a figure from a paper (shown below) in matplotlib. Each cell contains a percentage, and the higher the percentage, the darker the cell's background:

[Image: the target figure from the paper]

The code below produces something similar, but each cell is drawn as a square pixel, whereas I would like flatter rectangles, as in the image above. How can I achieve this with matplotlib?

import numpy as np
import matplotlib.pyplot as plt
import itertools

table = np.random.uniform(low=0.0, high=1.0, size=(10,5))

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure()
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1)
plt.yticks(np.arange(10), class_names)

for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
    plt.text(j, i, format(table[i,j], '.2f'),
             horizontalalignment="center",
             color="white" if table[i,j] > 0.5 else "black")

plt.show()

Here's what the code above produces: [Image: output with square cells]

I suspect that changing the aspect and extent arguments of imshow may help here. I don't fully understand how they work, but here's what I've tried:

plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect='equal', extent=[0,14,10,0])

This produces the following: [Image: output of the modified imshow call]
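To see how extent interacts with text placement, here is a minimal sketch of the arithmetic, assuming imshow's default origin='upper' (row 0 drawn at the `top` edge of extent=[left, right, bottom, top]). The helper name cell_centers is hypothetical, not a matplotlib function. With aspect='equal', one data unit has the same on-screen length on both axes, so extent=[0,14,10,0] over a 10 x 5 table gives cells 2.8 units wide and 1 unit tall, i.e. 2.8 times as wide as they are tall.

```python
import numpy as np


def cell_centers(extent, nrows, ncols):
    """Hypothetical helper: data coordinates of the cell centres of an
    nrows x ncols array drawn by imshow(..., extent=extent) with the
    default origin='upper', i.e. row 0 is drawn at the `top` edge."""
    left, right, bottom, top = extent
    cell_w = (right - left) / ncols  # width of one cell in data units
    cell_h = (top - bottom) / nrows  # signed: negative when top < bottom
    x_centers = left + (np.arange(ncols) + 0.5) * cell_w
    y_centers = top - (np.arange(nrows) + 0.5) * cell_h
    return x_centers, y_centers


# The default extent for a 10 x 5 array is (-0.5, 4.5, 9.5, -0.5):
# centres land on integers, which is why plt.text(j, i, ...) works.
xs0, ys0 = cell_centers((-0.5, 4.5, 9.5, -0.5), nrows=10, ncols=5)

# With extent=[0, 14, 10, 0] the centres move to (2.8*j + 1.4, i + 0.5).
xs, ys = cell_centers([0, 14, 10, 0], nrows=10, ncols=5)
print(xs, ys)
```

So after changing the extent, the text coordinates must be mapped the same way, e.g. plt.text(2.8*j + 1.4, i + 0.5, ..., ha="center", va="center").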

I realise I also need to add borders between the cells, remove the tick marks, and show the values as percentages rather than decimals. I'm confident I can do this myself, but if you want to help with that too, please feel free!
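For those finishing touches, here is one possible sketch, not the paper's code. It keeps the default extent so cells stay centred on integer coordinates, uses aspect="auto" with an arbitrary figsize=(6, 5) to get flat cells, draws borders as minor-tick gridlines at the half-integer cell edges, hides tick marks with tick_params, and formats values with Python's '.0%' format spec (0.734 becomes '73%').

```python
import numpy as np
import matplotlib.pyplot as plt

table = np.random.uniform(low=0.0, high=1.0, size=(10, 5))
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
nrows, ncols = table.shape

fig, ax = plt.subplots(figsize=(6, 5))
# aspect="auto" lets the cells stretch to the axes' shape (flat rectangles here)
ax.imshow(table, interpolation='nearest', cmap=plt.cm.Greys,
          vmin=0, vmax=1, aspect="auto")

# Major ticks label rows/columns at the cell centres (integers by default)
ax.set_yticks(np.arange(nrows))
ax.set_yticklabels(class_names)
ax.set_xticks(np.arange(ncols))
ax.set_xticklabels([str(j + 1) for j in range(ncols)])

# Minor ticks at the cell edges (-0.5, 0.5, ...) drive the border gridlines
ax.set_xticks(np.arange(-0.5, ncols), minor=True)
ax.set_yticks(np.arange(-0.5, nrows), minor=True)
ax.grid(which="minor", color="black", linestyle="-", linewidth=1)

# Hide all tick marks but keep the tick labels
ax.tick_params(which="both", length=0)

# Write each value as a percentage, centred in its cell
for i in range(nrows):
    for j in range(ncols):
        ax.text(j, i, f"{table[i, j]:.0%}", ha="center", va="center",
                color="white" if table[i, j] > 0.5 else "black")

plt.show()
```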

Comments:
  • Maybe this Stack Overflow link, "Imshow: extent and aspect", answers your question. (Commented Oct 19, 2017 at 8:46)
  • @Spezi94 Thanks, I saw this answer and tried it, but I can't work out how the text placement works after the aspect and extent have been changed. I'll update the question to reflect this. (Commented Oct 19, 2017 at 8:48)
  • Using extent=[0,14,0,10] gives a nice-looking figure (for me at least), though I also haven't figured out how to place the text. (Commented Oct 19, 2017 at 8:54)
  • @DavidG Looks nice, thanks! I've just noticed that the order of the labels is reversed; this can be fixed by using extent=[0,14,10,0] instead. (Commented Oct 19, 2017 at 8:59)
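The label reversal comes from the y limits that imshow derives from extent=[left, right, bottom, top]. A small sketch (random data, two hypothetical subplots side by side) makes this visible by checking the resulting y limits and whether the y axis is inverted:

```python
import numpy as np
import matplotlib.pyplot as plt

table = np.random.uniform(size=(10, 5))

fig, (ax_up, ax_down) = plt.subplots(1, 2)
# With origin='upper' (the default), row 0 is always drawn at the `top` value
ax_up.imshow(table, cmap=plt.cm.Greys, aspect='equal', extent=[0, 14, 0, 10])
ax_down.imshow(table, cmap=plt.cm.Greys, aspect='equal', extent=[0, 14, 10, 0])

# The y limits follow (bottom, top) from the extent
print(ax_up.get_ylim(), ax_up.yaxis_inverted())
print(ax_down.get_ylim(), ax_down.yaxis_inverted())
```

With [0,14,0,10] the y axis grows upward and row 0 occupies y from 9 to 10, so plt.yticks(np.arange(10)+0.5, class_names) puts 'airplane' next to the bottom row. With [0,14,10,0] the y axis is inverted, row 0 occupies y from 0 to 1 at the top, and the labels line up.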

2 Answers

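You will get non-square pixels by passing aspect="auto" to imshow. The axes then keep their shape and the image is stretched to fill them, so each cell becomes a rectangle. Because the extent is unchanged, the text positions (j, i) still hit the cell centres; va="center" centres the text vertically as well: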

import numpy as np
import matplotlib.pyplot as plt
import itertools

table = np.random.uniform(low=0.0, high=1.0, size=(10,5))

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure()
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect="auto")
plt.yticks(np.arange(10), class_names)

for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
    plt.text(j, i, format(table[i,j], '.2f'),
             ha="center", va="center",
             color="white" if table[i,j] > 0.5 else "black")

plt.show()

[Image: the resulting table with rectangular cells]
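If you want the cell shape to stay fixed regardless of the figure size, a numeric aspect is another option (a sketch; the value 0.35 is an arbitrary assumption). A numeric aspect is the on-screen height/width ratio of one data unit, and with the default extent each cell is one data unit square, so the text can still be placed at (j, i):

```python
import numpy as np
import matplotlib.pyplot as plt

table = np.random.uniform(low=0.0, high=1.0, size=(10, 5))
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

fig, ax = plt.subplots()
# aspect=0.35: every cell is drawn 0.35 times as tall as it is wide
ax.imshow(table, interpolation='nearest', cmap=plt.cm.Greys,
          vmin=0, vmax=1, aspect=0.35)
ax.set_yticks(np.arange(len(class_names)))
ax.set_yticklabels(class_names)

# The extent is unchanged, so cell (i, j) is still centred on (j, i)
for i in range(table.shape[0]):
    for j in range(table.shape[1]):
        ax.text(j, i, format(table[i, j], '.2f'), ha="center", va="center",
                color="white" if table[i, j] > 0.5 else "black")

plt.show()
```

The trade-off: aspect="auto" fills whatever space the axes have, while a numeric aspect fixes the cell proportions and may leave empty space around the table.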

Comment: This is much simpler than my solution. Thanks!

After lots of experimentation, I figured out how the extent works and how it affects the coordinates for the text. I also added the borders etc., and this code generates a pretty good replica of the original style:

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import itertools

# Set the font size before any text is created so it applies everywhere
matplotlib.rcParams.update({'font.size': 14})

table = np.random.uniform(low=0.0, high=1.0, size=(10,5))

class_names = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure()

# extent=[left, right, bottom, top]: 5 columns over 14 units (2.8 each) and
# 10 rows over 10 units (1 each); bottom=10, top=0 keeps row 0 at the top
plt.imshow(table, interpolation='nearest', cmap=plt.cm.Greys, vmin=0, vmax=1, aspect='equal', extent=[0,14,10,0])
# Major ticks sit at the cell centres
plt.yticks(np.arange(10)+0.5, class_names)
plt.xticks(np.arange(5)*2.8 + 1.4, ['1', '2', '3', '4', '5'])

# Use the existing axes (in recent matplotlib, plt.axes() adds a new, empty one)
ax = plt.gca()
ax.yaxis.set_ticks_position('none')
ax.xaxis.set_ticks_position('none')

# Minor ticks at the internal cell edges
ax.set_xticks(np.arange(1, 5) * 2.8, minor=True)
ax.set_yticks(np.arange(1, 10, 1), minor=True)

# Gridlines based on minor ticks
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)

# Cell (i, j) spans x in [2.8*j, 2.8*(j+1)] and y in [i, i+1]
for i,j in itertools.product(range(table.shape[0]), range(table.shape[1])):
    plt.text(j*2.8+1.4, i+0.5, format(table[i,j], '.2f'),
             ha="center", va="center",
             color="white" if table[i,j] > 0.5 else "black")

plt.show()

[Image: the final replica with cell borders]

Thanks to DavidG and Spezi94 who helped with their comments!
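A different technique from the one used in this answer: pcolormesh can draw the cell borders itself through edgecolors, which avoids the minor-tick gridlines. In this sketch, cell edges are passed explicitly so each cell spans one data unit and cell (i, j) is centred on (j + 0.5, i + 0.5); the figsize is an arbitrary assumption that controls how flat the cells are.

```python
import numpy as np
import matplotlib.pyplot as plt

table = np.random.uniform(low=0.0, high=1.0, size=(10, 5))
class_names = ['airplane', 'auto', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
nrows, ncols = table.shape

fig, ax = plt.subplots(figsize=(6, 5))
# Explicit edges: column j spans [j, j+1], row i spans [i, i+1]
ax.pcolormesh(np.arange(ncols + 1), np.arange(nrows + 1), table,
              shading='flat', cmap=plt.cm.Greys, vmin=0, vmax=1,
              edgecolors='black', linewidth=1)  # edgecolors draws the borders
ax.invert_yaxis()  # put row 0 ("airplane") at the top, as imshow does

# Ticks at the cell centres, tick marks hidden
ax.set_xticks(np.arange(ncols) + 0.5)
ax.set_xticklabels([str(j + 1) for j in range(ncols)])
ax.set_yticks(np.arange(nrows) + 0.5)
ax.set_yticklabels(class_names)
ax.tick_params(length=0)

for i in range(nrows):
    for j in range(ncols):
        ax.text(j + 0.5, i + 0.5, f"{table[i, j]:.0%}", ha="center", va="center",
                color="white" if table[i, j] > 0.5 else "black")

plt.show()
```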
