Проблема
В PyTorch я пытаюсь написать класс, который мог бы возвращать все data
и label
отдельно, используя синтаксис, такой как dataset.data
и dataset.label
. Скелет кода выглядит следующим образом:
class MyDataset(object):
data = _get_data()
label = _get_label()
def __init__(self, dir, transforms):
self.img_list = ... # all image paths loaded from dir
# do something
def __getitem__(self):
# do something
return data, label
def __len__(self):
return len(self.img_list)
def _get_data():
# do something
def _get_label():
# do something
Однако, когда я использую dataset.data
и dataset.label
для доступа к соответствующим переменным, ничего не возвращается.
Мне интересно, почему это так и как я могу это исправить.
Редактировать
Спасибо за все ваше внимание.
У меня естьрешил эту проблему сам. Решение довольно простое, которое просто использует свойство переменных класса .
class FaceDataset(object):
# class variable
data = None
label = None
def __init__(self, root, transforms=None):
# read img_list from root
img_list = ...
self.transforms = ...
FaceDataset.data = FaceDataset._get_data(self.img_list, self.transforms)
FaceDataset.label = FaceDataset._get_label(self.img_list)
@classmethod
def _get_data(cls, img_list, transforms):
data_list = []
for img_path in img_list:
data_list.append(transforms(Image.open(img_path)).unsqueeze(0))
return torch.stack(data_list, dim=0)
@classmethod
def _get_label(cls, img_list):
label = torch.zeros(len(img_list))
for i, img_path in enumerate(img_list):
label[i] = ...
return label
def __getitem__(self, index):
img_path = self.img_list[index]
label = ...
# read image from file
data = Image.open(img_path)
# apply transform defined in __init__
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.img_list)