Объект DataLoader не поддерживает индексацию - PullRequest
1 голос
/ 01 июля 2019

Я загрузил набор данных ImageNet через этот API Pytorch, установив download = True.Но я не могу выполнить итерацию по загрузчику данных.

Ошибка говорит: «Объект DataLoader не поддерживает индексирование»

trainset = torch.utils.data.DataLoader(
    datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
                      download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)

Я попробовал простой подход, я просто попытался выполнить следующее,

trainloader[0]

В корневом каталоге используется шаблон

root/  
    train/  
          n01440764/
          n01443537/ 
                   n01443537_2.jpg

В документах на официальном сайте больше ничего не сказано.https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet

Что я делаю не так?

Ответы [ 3 ]

1 голос
/ 02 июля 2019

Ну, ответ довольно прост (кроме ошибки, упомянутой в другом ответе).

DataLoader не имеет __getitem__ метода (см. в исходном коде для себя).

Используется для итерации, а не произвольного доступа к данным (или пакетам данных). Если вы хотите получить доступ к определенному элементу, вы должны использовать torch.utils.data.Dataset, в вашем случае:

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', )
trainset[0]

Получение партии

Если вы хотите получить пакет, вы можете перебрать его и потом разбить:

for batch in dataloader:
    print(batch) # or anything else you want to do
    break

DataLoader создает случайные индексы по умолчанию или заданным способом (см. samplers ), следовательно, нет __getitem__, так как это не имело бы смысла для этого объекта.

Вы также можете наследовать от DataLoader и создавать свою собственную функцию __getitem__, делающую то, что вы хотите (хотя и более сложную).

Полный пример

# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train', download=True)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)

for batch in trainloader:
    print(batch)
    break

Вверху должна быть напечатана первая партия, которая находится внутри.

0 голосов
/ 02 июля 2019

Раствор

input_transform = standard_transforms.Compose([
    transforms.Resize((255,255)), # to Make sure all the 
    transforms.CenterCrop(224),   # imgs are at the same size 
    transforms.ToTensor()
])  


# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
                             split='train', download=False, transform = input_transform)
# torch.utils.data.DataLoader object
trainloader =torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=False)


for batch_idx, data in enumerate(trainloader, 0):
    x, y = data 
    break
0 голосов
/ 01 июля 2019

Входной набор данных для torch.utils.data.DataLoader() должен иметь тип torch.utils.data.Dataset, а не torch.utils.data.DataLoader, что вы делаете в приведенном выше коде.

Итак, ваш код должен быть:

trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', 
                                          split='train', 
                                          download=False)

trainloader = torch.utils.data.DataLoader(trainset, 
                                          batch_size=1, 
                                          shuffle=False, 
                                          num_workers=1)

Подробнее см. Официальную документацию на факел здесь .

...