Добавить канал в MNIST через преобразование? - PullRequest
0 голосов
/ 15 февраля 2019

Я пытаюсь использовать набор данных MNIST из torchvision.datasets. Кажется, он предоставляется как тензор N x H x W (uint8) (размер партии, высота, ширина).Однако для всех классов pytorch для работы с изображениями (например, Conv2d) требуется тензор N x C x H x W (float32), где C - количество цветовых каналов.Я попытался добавить добавление преобразования ToTensor, но это не добавило цветной канал.

Есть ли способ использовать torchvision.transforms для добавления этого дополнительного измерения?Для необработанного tensor мы могли бы просто сделать .unsqueeze(1), но это не выглядит как очень элегантное решение.Я просто пытаюсь сделать это "правильным" способом.

Вот неудачное преобразование.

import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])

1 Ответ

0 голосов
/ 15 февраля 2019

У меня было неправильное представление: dataset.train_data - это , а не , на которое влияет указанное transform, будет только вывод DataLoader(dataset,...).После проверки data из

for data, _ in DataLoader(dataset):
    break

мы видим, что ToTensor на самом деле делает именно то, что нужно.

...