Как я могу перетасовать метки набора данных? - PullRequest
1 голос
/ 23 октября 2019

Я скачал набор данных MNIST, используя следующую команду:

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

Теперь мне нужно провести несколько экспериментов с этим набором данных (MNIST), но перетасовывать метки обучающего набора. Как я могу перетасовать / переназначить их случайным образом? Я пробовал следующее:

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            target_transform=lambda y: torch.randint(0, 10, (1,)).item(),
                            download=True)

Но я заметил, что то, что происходит после лямбда-функции, заставляет метки тасоваться в процессе обучения, например, они меняются в каждую эпоху. Таким образом, я не достигну 100% точности обучения, к чему я стремлюсь. Как я могу перетасовать эти метки совершенно случайно, убедившись, что эти метки не будут меняться в процессе обучения?

Спасибо !!

1 Ответ

1 голос
/ 23 октября 2019

В случае, если ваша цель состоит в том, чтобы создать случайное отображение меток, вам нужно определить это отображение перед определением целевого преобразования, чтобы сохранить его постоянным. Что-то вроде следующего должно сделать трюк

import random
label_mapping = list(range(10))
random.shuffle(label_mapping)
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            target_transform=lambda y: label_mapping[y],
                            download=True)

Для того, чтобы получить новую случайную смену в каждой эпохе, вы захотите переопределить отображение меток, набор обучающих данных и загрузчик данных для каждой эпохи.

Обновление Чтобы вместо этого сгенерировать случайную метку, которая не зависит от истинной метки, но совместима с данным индексом, вам, вероятно, нужно либо выполнить очень тщательное заполнение, либо переопределить некоторые функции класса набора данных.

Например, последний случай может выглядеть примерно так:

import random
class RandomMNIST(dsets.MNIST):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.targets = [random.randint(0, 9) for _ in range(len(self.data))]

train_dataset = RandomMNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

или эквивалентно

import random
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)
train_dataset.targets = [random.randint(0, 9) for _ in range(len(train_dataset))]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...