Вместо того чтобы использовать встроенные в Pytorch API-интерфейсы для этих наборов данных, я пытаюсь создать свой собственный набор данных и передать его в API-интерфейс DATASET и API-интерфейс DATALOADER Pytorch. Но почему-то я сталкиваюсь с некоторой ошибкой.
Мои данные в этом формате, который я создал, объединив все 4 рассола в один. ИЗОБРАЖЕНИЯ ЭТИКЕТКИ
После создания данных и после этого [CustomDataset] [3] я написал следующий код:
import numpy as np
import pickle as pkl
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# For custom dataset inherit the parent Dataset class into the child class
class CIFARDataset(Dataset):
"""CIFAR dataset."""
def __init__(self, pckl_path, transform=None):
"""
:param pckl_path:
:param transform:
"""
" Load the pickle files data"
pckl_fd = open(pckl_path, "rb")
self.data_pckl = pkl.load(pckl_fd)
self.transform = transform
def __len__(self):
return len(self.data_pckl)
def __getitem__(self, idx):
print("inside __get_item")
if torch.is_tensor(idx):
idx = idx.tolist()
sample = {'image': self.data_pckl['images'][idx], 'label': self.data_pckl['labels'][idx]}
if self.transform:
sample = self.transform(sample)
return sample
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
print("In ToTensor")
image, label = sample['images'], sample['labels']
image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image),
'label': torch.from_numpy(np.ndarray(label))}
dataset= CIFARDataset('cifar/train_set.pickle', transform=transforms.Compose(ToTensor()))
# composed = transforms.Compose([ToTensor()])
# sample = dataset.data_pckl
sample1 = {'images':None, 'labels': None}
data = dataset[0]
Когда яЗапустив это, я получаю следующую ошибку:
Ошибка:
data = dataset[0]
File "/home/garud/Documents/DSP_notes/Project/create_dataset.py", line 34, in __getitem__
sample = self.transform(sample)
File "/home/garud/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
for t in self.transforms:
TypeError: 'ToTensor' object is not iterable
Я отладил и проверил образец - это dict, который передается в функцию преобразования. Не знаю, где это идет не так.
Добрый совет, что не так и какие лучшие практики необходимо соблюдать, чтобы быть лучше в этом.