I get this error: `TypeError: Invalid shape (28, 28, 1) for image data`
Here is my code:
import torch
import torchvision
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline
# Load dataset
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz
from torchvision.datasets import MNIST
dataset = MNIST(root = './', train=True, download=True, transform=ToTensor())
#val_data = MNIST(root = './', train=False, download=True, transform=transform)
image, label = dataset[0]
print('image.shape:', image.shape)
plt.imshow(image.permute(1, 2, 0), cmap='gray') # HELP WITH THIS LINE
print('Label:', label)
I know that PyTorch stores images as C x H x W, while matplotlib expects H x W x C. Yet when I permute the tensor into matplotlib's layout, I get the error above on that line. Am I missing something? Why does this happen?
`np.squeeze()` — `tf` rather than `pt`. Updated.