Несоответствие формы вывода и трансляции в MNIST, torchvision - PullRequest
1 голос
/ 12 марта 2019

Я получаю следующую ошибку при использовании набора данных MNIST в Torchvision

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

Вот мой код:

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
images, labels = next(iter(trainloader))

1 Ответ

10 голосов
/ 12 марта 2019

Ошибка из-за цвета в градациях серого в наборе данных, набор данных в градациях серого.

Я исправил это, изменив преобразование на

transform = transforms.Compose([transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,))
])
...