Создание сокращенного набора данных из существующего набора данных Torchvision - PullRequest
0 голосов
/ 31 января 2019

Мы все знаем общий набор данных MNIST, включенный в пакет torchvision.datasets.Представьте, что я хочу создать сокращенную версию этого набора данных, содержащего только 1 и 0 , чтобы классифицировать только эти два числа вместо всех 10 значений.

Я виделэти пользовательские наборы данных могут быть созданы в классе, который наследует требуемый набор данных, поэтому __getitem__, который возвращает элемент с заданным индексом.Итак, я сделал это:

class MNIST01(MNIST):
    def __getitem__(self, idx):
        image, label = super().__getitem__(idx)
        if label.item() <= 1:
            return image, label
        else:
            return None

Проблема в том, что мне кажется, что я не могу вернуть значение None, так как оно должно быть "содержать тензоры, числа, диктовки или списки; найденный класс NoneType"'".

Есть ли простой способ получить уменьшенную версию этого набора данных аналогичным образом?

1 Ответ

0 голосов
/ 03 февраля 2019

Мне наконец удалось справиться с проблемой NoneType.Сохранение функции, определенной в вопросе.

class MNIST01(MNIST):
    def __getitem__(self, idx):
        features, target = super(MNIST01, self).__getitem__(idx)
        if target.item() <= 1:
            return features, target

Теперь нам нужно определить пользовательскую функцию сортировки collate_fn для нашего загрузчика данных, которая обрабатывает список выборок для формирования партии.В этой функции мы можем применить фильтр для обработки None значений и игнорировать их.

from torch.utils.data.dataloader import default_collate

def filter_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)

Тогда нам просто нужно передать эту функцию в DataLoader:

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, collate_fn=filter_collate, **kwargs)
test_loader = DataLoader(test_dataset, collate_fn=filter_collate, **kwargs)

Версия 2

Гораздо проще, чем первый, избегая некоторых проблем при доступе к данным.Просто отфильтруйте непосредственно атрибуты train_data и train_label (и соответствующие для набора тестов) из экземпляра класса MNIST.

train_dataset.train_data = train_dataset.train_data[train_dataset.train_labels <= 1]
train_dataset.train_labels = train_dataset.train_labels[train_dataset.train_labels <= 1]
...