Как извлечь только подмножество классов из torchvision.datasets.CIFAR10? - PullRequest
0 голосов
/ 26 января 2019

Как извлечь только 2 или 3 класса из torchvision.datasets.CIFAR10?

Стандартный способ загрузки всех 10 классов

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

1 Ответ

0 голосов
/ 26 января 2019

Изучив код из CIFAR10, вы увидите, что данные хранятся в виде массива numpy, а метки - в виде списка.Поэтому вы можете создать подкласс этого класса и соответствующим образом отфильтровать два массива.Пример ниже:

class SubLoader(torchvision.datasets.CIFAR10):
    def __init__(self, *args, exclude_list=[], **kwargs):
        super(SubLoader, self).__init__(*args, **kwargs)

        if exclude_list == []:
            return

        if self.train:
            labels = np.array(self.train_labels)
            exclude = np.array(exclude_list).reshape(1, -1)
            mask = ~(labels.reshape(-1, 1) == exclude).any(axis=1)

            self.train_data = self.train_data[mask]
            self.train_labels = labels[mask].tolist()
        else:
            labels = np.array(self.test_labels)
            exclude = np.array(exclude_list).reshape(1, -1)
            mask = ~(labels.reshape(-1, 1) == exclude).any(axis=1)

            self.test_data = self.test_data[mask]
            self.test_labels = labels[mask].tolist()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...