Я использую набор данных Omniglot, который представляет собой набор из 19 280 изображений, каждое из которых имеет размер 105 x 105 (в оттенках серого).
Я определил пользовательский класс набора данных со следующим преобразованием:
class OmniglotDataset(Dataset):
def __init__(self, X, transform=None):
self.X = X
self.transform = transform
def __len__(self):
return self.X.shape[0]
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img = self.X[idx]
if self.transform:
img = self.transform(img)
return img
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
X_train.shape
(19280, 105, 105)
train_dataset = OmniglotDataset(X_train, transform=img_transform)
Когда я индексирую одно изображение, он возвращает правильные размеры:
train_dataset[0].shape
torch.Size([1, 105, 105])
Но когда я индексирую несколько изображений, он возвращает размеры в неправильном порядке (я ожидаю 3 x 105 x 105
):
train_dataset[[1,2,3]].shape
torch.Size([105, 3, 105])