PyTorch DataLoader shuffle - PullRequest
       65

PyTorch DataLoader shuffle

0 голосов
/ 09 апреля 2020

Я провел эксперимент и не получил ожидаемого результата.

Для первой части я использую

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=False, num_workers=0)

Я сохраняю trainloader.dataset.targets в переменной a и trainloader.dataset.data в переменной b перед тренировкой моей модели. Затем я тренирую модель, используя trainloader. После окончания обучения я сохраняю trainloader.dataset.targets в переменной c и trainloader.dataset.data в переменной d. Наконец, я проверяю a == c и b == d, и они оба дают True, что ожидалось, потому что параметр shuffle DataLoader равен False.

Для второй части я использую

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=True, num_workers=0)

Я сохраняю trainloader.dataset.targets в переменной e и trainloader.dataset.data в переменной f перед тренировкой моей модели. Затем я тренирую модель, используя trainloader. После окончания обучения я сохраняю trainloader.dataset.targets в переменной g и trainloader.dataset.data в переменной h. Я ожидаю, что e == g и f == h будут оба False с shuffle=True, но они снова дают True. Чего мне не хватает в определении DataLoader класса?

1 Ответ

1 голос
/ 09 апреля 2020

Я считаю, что данные, которые хранятся непосредственно в trainloader.dataset.data или .target, не будут перетасованы, данные перетасовываются только тогда, когда DataLoader вызывается как генератор или как итератор

You можете проверить это, выполнив следующее (iter (trainloader)) несколько раз без тасования и тасования, и они должны дать разные результаты

import torch
import torchvision

transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        ])
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = False,
                                         num_workers = 10)
target = dataLoader.dataset.targets


MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)

dataLoader_shuffled= torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = True,
                                         num_workers = 10)

target_shuffled = dataLoader_shuffled.dataset.targets

print(target == target_shuffled)

_, target = next(iter(dataLoader));
_, target_shuffled = next(iter(dataLoader_shuffled))

print(target == target_shuffled)

Это даст:

tensor([True, True, True,  ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False,  True, False,
        False, False, False, False, False, False, False, False])

Однако данные и метки, хранящиеся в data и target, представляют собой фиксированный список, и, поскольку вы пытаетесь получить к нему прямой доступ, они не будут перетасовываться.

...