
I have this error: TypeError: Invalid shape (28, 28, 1) for image data

Here is my code:

import torch
import torchvision
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline

# Load dataset

!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz


dataset = MNIST(root = './', train=True, download=True, transform=ToTensor())
#val_data = MNIST(root = './', train=False, download=True, transform=transform)

image, label = dataset[0]
print('image.shape:', image.shape)
plt.imshow(image.permute(1, 2, 0), cmap='gray') # HELP WITH THIS LINE
print('Label:', label)

I know that PyTorch stores images as C x H x W, while matplotlib expects H x W x C. Yet even after permuting the tensor to matplotlib's layout, I still get the error on this line. Am I missing something? Why does this happen?
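For reference, permute(1, 2, 0) builds the new axis order from old axis indices: new axis 0 is old axis 1 (H), new axis 1 is old axis 2 (W), and new axis 2 is old axis 0 (C). A minimal NumPy-only sketch of the shapes involved, assuming a zero array as a stand-in for the MNIST tensor (so it runs without torch) and np.transpose in place of Tensor.permute:

```python
import numpy as np

# Stand-in for the MNIST sample: 1 channel, 28 x 28 pixels, stored as C x H x W.
chw = np.zeros((1, 28, 28), dtype=np.float32)

# Same axis order as image.permute(1, 2, 0): the result is H x W x C.
hwc = np.transpose(chw, (1, 2, 0))

print(chw.shape)  # (1, 28, 28)
print(hwc.shape)  # (28, 28, 1)
print(hwc.ndim)   # 3
```

The permute does move the channel axis to the end, but the result is still a 3D array whose last dimension is 1.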

  • Try to squeeze the last dimension using np.squeeze() Commented Mar 29, 2021 at 9:54
  • Thanks, this worked! Can you explain the rationale behind np.squeeze()? Why does this method work? Commented Mar 29, 2021 at 10:00
  • I added an answer with explanation, I suggest deleting yours and kindly accepting my answer :) Commented Mar 29, 2021 at 10:05
  • @zampoan it's weird that you tagged tf rather than pt. Updated. Commented Mar 29, 2021 at 10:12

2 Answers


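plt.imshow() expects either a 2D array of shape (M, N), which it renders using a colormap, or a 3D array whose last dimension is 3 (RGB) or 4 (RGBA). In your case the array has shape (28, 28, 1): it is 3D, but its last dimension is 1, so imshow() rejects it with "Invalid shape".

So the size-1 last dimension has to be squeezed out. np.squeeze() removes axes of length one; with axis=2 it removes only the channel axis, leaving a 2D (28, 28) array that imshow() can display with cmap='gray'.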

plt.imshow(np.squeeze(image.permute(1, 2, 0), axis = 2), cmap='gray')
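The shape rule above can be written out as a small check. This is a sketch that mirrors the shape test imshow() applies (2D, or 3D with a last dimension of 3 or 4); the helper names imshow_accepts and to_imshow_array are hypothetical, and a NumPy array stands in for the torch tensor:

```python
import numpy as np

def imshow_accepts(arr: np.ndarray) -> bool:
    """Hypothetical helper: True if the shape is (M, N), (M, N, 3) or (M, N, 4)."""
    return arr.ndim == 2 or (arr.ndim == 3 and arr.shape[-1] in (3, 4))

def to_imshow_array(chw: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert a C x H x W image into a shape imshow() accepts."""
    hwc = np.transpose(chw, (1, 2, 0))  # C x H x W -> H x W x C
    channels = hwc.shape[-1]
    if channels == 1:
        # Grayscale: drop the size-1 channel axis so the result is 2D (H, W).
        return np.squeeze(hwc, axis=2)
    if channels in (3, 4):
        # RGB or RGBA: H x W x C is already a valid imshow() shape.
        return hwc
    raise ValueError(f"cannot display {channels} channels with imshow()")

gray = np.zeros((1, 28, 28))
print(imshow_accepts(np.transpose(gray, (1, 2, 0))))  # False: shape (28, 28, 1)
print(to_imshow_array(gray).shape)                    # (28, 28)
print(imshow_accepts(to_imshow_array(gray)))          # True
```

With a real CPU tensor you would pass image.numpy() to the helper and hand the result to plt.imshow(..., cmap='gray').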

1 Comment

@zampoan this is the right and only acceptable answer. Please don't repeat the same post.

The accepted answer is correct, but it doesn't take the shortest path. image has shape [1, 28, 28]. image.permute(1, 2, 0) turns it into [28, 28, 1], which then needs a squeeze to become [28, 28]. A shorter option is to index the single channel directly with image[0], which is already [28, 28]: plt.imshow(image[0], cmap='gray')
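A quick way to check that both routes give the same pixels. This NumPy sketch assumes a random array in place of the MNIST tensor; with torch, image[0] and image.permute(1, 2, 0).squeeze(2) behave the same way:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((1, 28, 28))  # stand-in for the [1, 28, 28] MNIST tensor

# Accepted answer's route: [1, 28, 28] -> [28, 28, 1] -> [28, 28]
via_permute = np.squeeze(np.transpose(image, (1, 2, 0)), axis=2)
# Shorter route: select channel 0 directly -> [28, 28]
via_index = image[0]

print(via_permute.shape, via_index.shape)      # (28, 28) (28, 28)
print(np.array_equal(via_permute, via_index))  # True
```

Indexing only works because there is a single channel; an RGB image of shape [3, H, W] still needs permute(1, 2, 0) to become H x W x 3.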
