Matplotlib Plot, показывающий изображение в виде отдельных изображений RGB, а не одного изображения RGB - PullRequest
0 голосов
/ 01 июня 2019

У меня есть изображение, представленное в виде Numpy Array формы [224, 224, 3]. я пытаюсь построить это с помощью matplotlib, используя: plt.imshow (IMG) Но вместо того, чтобы получить одно изображение RGB, оно строит отдельные изображения R, G и B на одном графике. Куда я иду не так?

Я попытался посмотреть на форму изображения, а также несколько примеров для построения изображения. Переменная «img» имеет форму [224, 224, 3] и является типом массива numpy.

        from torchvision import datasets, transforms
        from torch.utils.data import DataLoader
        import matplotlib.pyplot as plt


        # Define Image Transform
        transform = transforms.Compose([transforms.Resize(255),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

        # Load Custom Image Dataset
        dataset = datasets.ImageFolder(root="./Cat_Dog_data", 
                                          transform=transform)

        # DataLoader
        dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

        # Get one batch of Data
        # len(images): 32
        # len(labels): 32
        # shape of images[0]: torch.Size([3, 224, 224])
        images, labels = next(iter(dataLoader))

        # img.shape: [224,224,3]
        img = images[0].numpy().reshape([224, 224, 3])

        plt.imshow(img)
        plt.show()

Я ожидаю, что изображение будет одним изображением RGB собаки или кошки. Но вывод, который я получаю, это график компонентов R, G, B этого изображения в виде столбцов на одном графике, как показано ниже. enter image description here

1 Ответ

0 голосов
/ 01 июня 2019

Код наконец-то работал с использованием функции np.transpose () вместо функции np.reshape ().

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Define Image Transform
transform = transforms.Compose([transforms.Resize(255),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

# Load Custom Image Dataset
dataset = datasets.ImageFolder(root="./Cat_Dog_data", transform=transform)

# DataLoader is a Generator
dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

# Get one batch of Data
images, labels = next(iter(dataLoader))

# Use transpose instead of reshape.
img = images[0].numpy().transpose((1, 2, 0))

plt.imshow(img)
plt.show()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...