Мне наконец удалось справиться с проблемой NoneType.Сохранение функции, определенной в вопросе.
class MNIST01(MNIST):
def __getitem__(self, idx):
features, target = super(MNIST01, self).__getitem__(idx)
if target.item() <= 1:
return features, target
Теперь нам нужно определить пользовательскую функцию сортировки collate_fn
для нашего загрузчика данных, которая обрабатывает список выборок для формирования партии.В этой функции мы можем применить фильтр для обработки None
значений и игнорировать их.
from torch.utils.data.dataloader import default_collate
def filter_collate(batch):
batch = list(filter(lambda x: x is not None, batch))
return default_collate(batch)
Тогда нам просто нужно передать эту функцию в DataLoader
:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, collate_fn=filter_collate, **kwargs)
test_loader = DataLoader(test_dataset, collate_fn=filter_collate, **kwargs)
Версия 2
Гораздо проще, чем первый, избегая некоторых проблем при доступе к данным.Просто отфильтруйте непосредственно атрибуты train_data
и train_label
(и соответствующие для набора тестов) из экземпляра класса MNIST
.
train_dataset.train_data = train_dataset.train_data[train_dataset.train_labels <= 1]
train_dataset.train_labels = train_dataset.train_labels[train_dataset.train_labels <= 1]