
In my Jupyter notebook, I am trying to display an image that I am iterating over through Keras. The code I am using is below:

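import matplotlib.pyplot as plt
import file_utils  # the asker's own helper module (not shown)

def plotImages(path, num):
    batchGenerator = file_utils.fileBatchGenerator(path + "train/", num)
    imgs, labels = next(batchGenerator)
    fig = plt.figure(figsize=(224, 224))
    plt.gray()
    for i in range(num):
        sub = fig.add_subplot(num, 1, i + 1)
        # imgs[i, 0] selects only the first channel of image i
        sub.imshow(imgs[i, 0], interpolation='nearest')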

But this plots only a single channel, so my image appears in grayscale. How do I use all three channels to plot a colour image?

  • What is the shape of imgs? Commented Jan 25, 2017 at 16:33
  • 224*224*3; in the code above I am plotting only channel 1. Commented Jan 25, 2017 at 16:37
    So just do sub.imshow(imgs) to use all channels. How do you expect it to show RGB if you don't provide all channels? Commented Jan 25, 2017 at 16:38

2 Answers


If you want to display an RGB image, you have to supply all three channels. Based on your code, you are displaying only the first channel, so matplotlib has no information with which to display it as RGB. Instead, it maps the values to the gray colormap, because you've called plt.gray().

Instead, pass all channels of the RGB image to imshow; then the true-colour display is used and the figure's colormap is disregarded:

sub.imshow(imgs, interpolation='nearest')
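A minimal, self-contained sketch of the difference, using a random array as a stand-in for one image from the generator (the data and the name `rgb` are assumptions for illustration). A 2-D array is mapped through the current colormap, while an array of shape (height, width, 3) is drawn as true colour and the colormap is not used.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for one image: 224 x 224 pixels, 3 channels last, values in [0, 1]
rng = np.random.default_rng(0)
rgb = rng.random((224, 224, 3))

fig, (ax_gray, ax_rgb) = plt.subplots(1, 2)
plt.gray()  # sets the default colormap to gray, as in the question

# Single channel: a 2-D array, so matplotlib maps values through the gray colormap
gray_im = ax_gray.imshow(rgb[:, :, 0], interpolation="nearest")

# All three channels: a (H, W, 3) array is drawn as true colour
rgb_im = ax_rgb.imshow(rgb, interpolation="nearest")

print(gray_im.get_array().shape, gray_im.get_cmap().name)  # (224, 224) gray
print(rgb_im.get_array().shape)  # (224, 224, 3)
```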

Update

Since imgs is actually 2 x 3 x 224 x 224 (channels first), you'll want to index into imgs and permute the dimensions to 224 x 224 x 3 before displaying the image:

im2display = imgs[i].transpose((1, 2, 0))  # (3, 224, 224) -> (224, 224, 3)
sub.imshow(im2display, interpolation='nearest')
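Putting the pieces together, the loop from the question could look roughly like this. It is a sketch that assumes `imgs` is a NumPy array of shape (num, 3, height, width), as reported in the comments; the generator from `file_utils` is not shown, so a random array stands in for it, and the function name `plot_batch` is hypothetical. Note that `figsize` is in inches, so `(224, 224)` would request an enormous figure. The rescaling step is also an assumption: matplotlib clips float RGB data to [0, 1], so float pixels in the 0-255 range would otherwise display mostly white.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_batch(imgs):
    """Plot each image of a channels-first batch (num, 3, H, W) in colour.

    Hypothetical helper standing in for plotImages from the question.
    """
    num = imgs.shape[0]
    fig = plt.figure(figsize=(4, 4 * num))  # figsize is in inches, not pixels
    for i in range(num):
        # (3, H, W) -> (H, W, 3): matplotlib wants the channel axis last
        img = imgs[i].transpose((1, 2, 0))
        # Assumption: float pixels in 0-255 need rescaling to [0, 1]
        if img.dtype.kind == "f" and img.max() > 1.0:
            img = img / 255.0
        sub = fig.add_subplot(num, 1, i + 1)
        sub.imshow(img, interpolation="nearest")
    return fig


# Stand-in for next(batchGenerator): 2 images, 3 channels, 224 x 224, values 0-255
imgs = np.random.default_rng(0).uniform(0, 255, size=(2, 3, 224, 224))
fig = plot_batch(imgs)
print([ax.images[0].get_array().shape for ax in fig.axes])
```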

4 Comments

I am getting an error: 'TypeError: Invalid dimensions for image data'
@Abhik You said your image was 224 x 224 x 3. What does imgs.shape show?
(2, 3, 224, 224): 2 images of 3*224*224. Hence I am doing sub.imshow(imgs[i], interpolation='nearest'), where i is the index of the image.
TypeError: transpose() received an invalid combination of arguments - got (tuple), but expected one of: * (int dim0, int dim1) * (name dim0, name dim1)
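The TypeError in the last comment looks like the message PyTorch produces: `torch.Tensor.transpose` swaps exactly two dimensions, so it rejects a tuple, whereas NumPy's `transpose` accepts a full axis order. A hedged sketch of a helper that handles both, assuming the tensor exposes `.detach()`, `.cpu()` and `.numpy()` as `torch.Tensor` does; the name `to_hwc` is hypothetical, and only the NumPy path is exercised here.

```python
import numpy as np


def to_hwc(img):
    """Return a (H, W, C) NumPy array from one channels-first (C, H, W) image.

    Hypothetical helper. Accepts a NumPy array or, by duck typing, a PyTorch
    tensor (assumed to expose .detach(), .cpu() and .numpy()).
    """
    if hasattr(img, "detach"):
        # PyTorch tensor: its transpose() swaps only two dims, so convert
        # to NumPy first (img.permute(1, 2, 0) would also work in torch)
        img = img.detach().cpu().numpy()
    img = np.asarray(img)
    # NumPy's transpose takes a full axis order: (C, H, W) -> (H, W, C)
    return np.transpose(img, (1, 2, 0))


chw = np.zeros((3, 224, 224))
print(to_hwc(chw).shape)  # (224, 224, 3)
```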

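This should work for you: put the channels last by permuting the dimensions (this assumes image is a PyTorch tensor, since permute is a tensor method),

image.permute(2, 3, 1, 0)

and then remove the image index with np.squeeze(), which only drops it when the batch holds a single image:

plt.imshow(np.squeeze(image.permute(2, 3, 1, 0)))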
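For comparison, the same reordering on a NumPy array uses `np.transpose` with the axis order (2, 3, 1, 0), turning (N, C, H, W) into (H, W, C, N); `np.squeeze` then removes the trailing axis only because N is 1. The sketch assumes a hypothetical single-image batch of shape (1, 3, 224, 224). Moving channels last with (0, 2, 3, 1) and indexing `[0]` gives the same image and also works for larger batches, one image at a time.

```python
import numpy as np

# Hypothetical batch holding a single channels-first image: (N, C, H, W)
batch = np.random.default_rng(0).random((1, 3, 224, 224))

# Same axis order as image.permute(2, 3, 1, 0): (N, C, H, W) -> (H, W, C, N)
hwcn = np.transpose(batch, (2, 3, 1, 0))
print(hwcn.shape)  # (224, 224, 3, 1)

# squeeze drops length-1 axes, leaving (H, W, C) only because N == 1
img = np.squeeze(hwcn)
print(img.shape)  # (224, 224, 3)

# Equivalent that works for any N: move channels last, then pick image 0
img0 = np.transpose(batch, (0, 2, 3, 1))[0]
assert np.array_equal(img, img0)
```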
