Pytorch - невозможно разрезать набор данных MNIST - PullRequest
0 голосов
/ 18 января 2019

В Pytorch при использовании набора данных torchvision MNIST мы можем получить следующую цифру:

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

tsfm = transforms.Compose([transforms.Resize((16, 16)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])

mnist_ds = torchvision.datasets.MNIST(root='../../../_data/mnist',train=True,
                                download=True, transform=tsfm)


digit_12 = mnist_ds[12]

Хотя на большинство наборов данных можно нарезать, мы не можем нарезать на этот:

digit_12_to_14 = mnist_ds[12:15]

вернется

ValueError: Too many dimensions: 3 > 2.

Это связано с тем, что Image.fromarray () в getItem ()

Можно ли использовать набор данных MNIST без использования загрузчика данных? Как?

PS: Причина, по которой я хотел бы избежать использования Dataloader, заключается в том, что отправка пакетов по одному в GPU замедляет обучение. Я предпочитаю отправлять в GPU целые данные только один раз. Для этого мне нужен доступ ко всему преобразованному набору данных.

Ответы [ 2 ]

0 голосов
/ 18 января 2019

Я нашел 2 решения для преобразования набора данных torchvision MNIST в тензоры. Первый получен из комментария Фабио Переса:

print("\nFirst...")
st = time()
x_all_ts = torch.tensor([mnist_ds[i][0].numpy() for i in range(0, len(mnist_ds))])
t_all_ts = mnist_ds.train_labels
print(f"{time()-st}   images:{x_all_ts.size()}  targets:{t_all_ts.size()} ")

print("\nSecond...")
st = time()
mnist_dl = DataLoader(dataset=mnist_ds, batch_size=len(mnist_ds))
x_all_ts2, t_all_ts2 = list(mnist_dl)[0]
print(f"{time()-st}   images:{x_all_ts2.size()}  targets:{t_all_ts2.size()} ")


First...
19.573785066604614   images:torch.Size([60000, 1, 16, 16])  targets:torch.Size([60000]) 
Second...
16.826476573944092   images:torch.Size([60000, 1, 16, 16])  targets:torch.Size([60000]) 

Пожалуйста, дайте мне знать, если вы найдете лучшие.

0 голосов
/ 18 января 2019

Для интерфейса Dataset требуется только

Все подклассы должны переопределять __len__, который обеспечивает размер набора данных, и __getitem__, поддерживающий целочисленную индексацию в диапазоне от 0 до len(self) exclusive.

, который явно не упоминает нарезку - поведение нарезки других наборов данных является дополнительной функцией. Если вы хотите получить все данные сразу, вы можете посмотреть реализацию и просто использовать тензоры mnist.data и mnist.targets, определенные к концу __init__.

Если вы хотите преобразовать данные, вы можете использовать

data = [mnist_ds[i] for i in range(len(mnist_ds))]
xs = torch.stack([d[0] for d in data], dim=0)
ys = torch.stack([d[1] for d in data], dim=0)

или преобразовать тензор mnist.data сразу (хотя это не будет работать с преобразованиями torchvision.transform).

...