Создание собственного набора данных в PyTorch - PullRequest
0 голосов
/ 22 октября 2019

Проблема

В PyTorch я пытаюсь написать класс, который мог бы возвращать все data и label отдельно, используя синтаксис, такой как dataset.data и dataset.label. Скелет кода выглядит следующим образом:

class MyDataset(object):
  data = _get_data()
  label = _get_label()
  def __init__(self, dir, transforms):
    self.img_list = ... # all image paths loaded from dir
    # do something 

  def __getitem__(self):
    # do something
    return data, label

  def __len__(self):
    return len(self.img_list)

  def _get_data():
    # do something

  def _get_label():
    # do something

Однако, когда я использую dataset.data и dataset.label для доступа к соответствующим переменным, ничего не возвращается.

Мне интересно, почему это так и как я могу это исправить.

Редактировать

Спасибо за все ваше внимание.

У меня естьрешил эту проблему сам. Решение довольно простое, которое просто использует свойство переменных класса .

class FaceDataset(object):
    # class variable
    data = None
    label = None

    def __init__(self, root, transforms=None):
        # read img_list from root
        img_list = ...
        self.transforms = ...
        FaceDataset.data = FaceDataset._get_data(self.img_list, self.transforms)
        FaceDataset.label = FaceDataset._get_label(self.img_list)

    @classmethod
    def _get_data(cls, img_list, transforms):
        data_list = []
        for img_path in img_list:
            data_list.append(transforms(Image.open(img_path)).unsqueeze(0))
        return torch.stack(data_list, dim=0)

    @classmethod
    def _get_label(cls, img_list):
        label = torch.zeros(len(img_list))
        for i, img_path in enumerate(img_list):
            label[i] = ...
        return label

    def __getitem__(self, index):
        img_path = self.img_list[index]
        label = ...

        # read image from file
        data = Image.open(img_path)
        # apply transform defined in __init__
        data = self.transforms(data)

        return data, label

    def __len__(self):
        return len(self.img_list)

1 Ответ

1 голос
/ 25 октября 2019

На "нормальный" способ создания пользовательских наборов данных в Python уже получен ответ здесь в SO. Для этого есть официальное руководство по PyTorch .

Для простого примера вы можете прочитать код набора данных PyTorch MNIST здесь (этот набор данных используется в этомPyTorch пример кода для дальнейшей иллюстрации). Наконец, вы можете найти другие реализации набора данных в этих наборах данных torchvision список (нажмите на имя набора данных, затем на кнопку «источник» в документации набора данных, чтобы получить доступ к реализации PyTorch набора данных).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...