Я пытаюсь использовать набор данных MNIST из torchvision.datasets
. Кажется, он предоставляется как тензор N x H x W (uint8)
(размер партии, высота, ширина).Однако для всех классов pytorch для работы с изображениями (например, Conv2d
) требуется тензор N x C x H x W (float32)
, где C
- количество цветовых каналов.Я попытался добавить добавление преобразования ToTensor
, но это не добавило цветной канал.
Есть ли способ использовать torchvision.transforms
для добавления этого дополнительного измерения?Для необработанного tensor
мы могли бы просто сделать .unsqueeze(1)
, но это не выглядит как очень элегантное решение.Я просто пытаюсь сделать это "правильным" способом.
Вот неудачное преобразование.
import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])