I have a plot with multiple images, arranged in 20 rows and 3 columns. In each row, the 1st column should be the original image, the 2nd column the mask, and the 3rd column the segmentation obtained from a neural network.

I tried the code below. The first 2 rows are plotted correctly, but from the 3rd row onward the images overlap the 1st row.

f = plt.figure()
for j in range(3):
    pred_y = y_pred[j]
    pred_y = pred_y.reshape(256, 256)
    pred_y = (pred_y > 0.1).astype(np.uint8)
    f.add_subplot(j + 1, 3, 1 )
    plt.imshow(X_test[j, :, :])

    f.add_subplot(j + 1, 3, 2)
    plt.imshow(y_test[j])
    
    f.add_subplot(j + 1, 3, 3)
    plt.imshow(pred_y)
plt.show()

The obtained output is:

[image: the resulting plot]

1 Answer

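The arguments to Figure.add_subplot are nrows, ncols, index (where index is the 1-based position in the grid, counted row by row), not row, num_rows, col. Your loop is therefore drawing the top row of plots in a 1x3, then a 2x3, then a 3x3 arrangement. Keep the grid shape fixed and vary only the index, like this: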

f.add_subplot(20, 3, 3 * j + 1)
f.add_subplot(20, 3, 3 * j + 2)
f.add_subplot(20, 3, 3 * j + 3)
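
To make the index arithmetic explicit, here is a minimal sketch of the row-major numbering that add_subplot uses. The helper name subplot_index is hypothetical (it is not part of Matplotlib), and it assumes 0-based row and column coordinates.

```python
def subplot_index(row, col, ncols):
    """Return the 1-based, row-major index that add_subplot expects.

    Hypothetical helper for illustration; row and col are 0-based.
    """
    return row * ncols + col + 1


# Indices for the first two rows of a 3-column grid.
for row in range(2):
    print([subplot_index(row, col, 3) for col in range(3)])
# prints [1, 2, 3] then [4, 5, 6]
```

With such a helper, the third column of row j would be f.add_subplot(20, 3, subplot_index(j, 2, 3)).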

A better way to generate your axes might be plt.subplots, as suggested in the demo:

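import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(20, 3, constrained_layout=True)
for i in range(20):
    ax[i, 0].imshow(X_test[i])  # original image
    ax[i, 1].imshow(y_test[i])  # ground-truth mask
    # thresholded network prediction
    ax[i, 2].imshow((y_pred[i].reshape(256, 256) > 0.1).astype(np.uint8))
plt.show()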

2 Comments

It really helped. Thank you. But the size is very small. Is there any way to increase it? The corresponding output can be seen at drive.google.com/file/d/1nHrP3oHp0HDu-bwlJ05kJ6MdnYh3gRO2/…
@Aleena. Don't display 20 rows all at once. Maybe cap it at 3 rows per figure, or use a separate figure for each image.
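
One possible way to follow the suggestion in the last comment is sketched below: split the samples into several figures and give each figure an explicit figsize. The function name plot_in_chunks and the defaults rows_per_fig=3 and row_height=3.0 are assumptions made for illustration, not values from the original post. The sketch also assumes each prediction is already a 2-D array (reshape beforehand if it is flat).

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_in_chunks(images, masks, preds, rows_per_fig=3,
                   row_height=3.0, threshold=0.1):
    """Return a list of figures, each showing up to rows_per_fig samples.

    Hypothetical helper: the defaults are illustrative, and each entry of
    `preds` is assumed to be a 2-D array of scores.
    """
    figures = []
    n = len(images)
    for start in range(0, n, rows_per_fig):
        stop = min(start + rows_per_fig, n)
        n_rows = stop - start
        # squeeze=False keeps `axes` 2-D even when a figure has one row.
        fig, axes = plt.subplots(
            n_rows, 3,
            figsize=(9, row_height * n_rows),
            squeeze=False,
            constrained_layout=True,
        )
        for r, i in enumerate(range(start, stop)):
            axes[r, 0].imshow(images[i])   # original image
            axes[r, 1].imshow(masks[i])    # ground-truth mask
            # binarize the prediction the same way as in the answer
            binary = (np.asarray(preds[i]) > threshold).astype(np.uint8)
            axes[r, 2].imshow(binary)
        figures.append(fig)
    return figures
```

It could then be called as plot_in_chunks(X_test[:20], y_test[:20], y_pred[:20].reshape(-1, 256, 256)) followed by plt.show(). Increasing row_height, or lowering rows_per_fig, makes each image larger.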
