Как отобразить одно изображение в PyTorch? - PullRequest
0 голосов
/ 05 декабря 2018

Я хочу отобразить одно изображение.Он был загружен с помощью ImageLoader и хранится в PyTorch Tensor.

Когда я пытаюсь отобразить его с помощью plt.imshow(image), я получаю:

TypeError: Invalid dimensions for image data

.shape тензора:

torch.Size([3, 244, 244])

Как отобразить изображение, содержащееся в этом тензоре PyTorch?

Ответы [ 4 ]

0 голосов
/ 16 марта 2019

Учитывая Tensor, представляющее изображение, используйте .permute():

plt.imshow(  tensor_image.permute(1, 2, 0)  )

Примечание: permute не копирует и не выделяет память from_numpy() тоже нет.

0 голосов
/ 05 декабря 2018

Как видите, matplotlib отлично работает даже без преобразования в массив numpy.Но PyTorch Tensors («Тензор изображений») - это сначала каналы, поэтому для использования их с matplotlib вам необходимо изменить его:

Код:

from scipy.misc import face
import matplotlib.pyplot as plt
import torch

np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# reshape to channel first:
tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1])
print(type(tensor_image), tensor_image.shape)

# If you try to plot image with shape (C, H, W)
# You will get TypeError:
# plt.imshow(tensor_image)

# So we need to reshape it to (H, W, C):
tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0])
print(type(tensor_image), tensor_image.shape)

plt.imshow(tensor_image)
plt.show()

Вывод:

<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
0 голосов
/ 13 марта 2019

Полный пример с указанием пути к изображению img_path:

from PIL import Image
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")

Обратите внимание, что transforms.* возвращает функцию, вот почему фанки заключают в скобки.

0 голосов
/ 05 декабря 2018

Учитывая, что изображение загружается, как описано, и сохраняется в переменной image:

plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")

В учебнике matplotlib изображения написано:

Бикубическая интерполяция часто используется при взрыве фотографий - люди, как правило, предпочитают размытые, а не пиксельные.


Или, как предположил Сумит :

%matplotlib inline
def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

Или, чтобы открыть изображение во всплывающем окне:

 transforms.ToPILImage()(image).show()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...