Как объект набора данных pytorch узнает, достиг ли он конца при использовании в for l oop? - PullRequest
2 голосов
/ 06 августа 2020

Я пишу собственный набор данных pytorch. В __init__ объект набора данных загружает файл, содержащий определенные данные. Но в моей программе я использую только wi sh для доступа к части данных (для достижения обучающего / действительного сокращения, если это помогает). Первоначально я думал, что это поведение контролируется переопределением __len__, но оказалось, что изменение __len__ не помогает. Вот простой пример:

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)

Выходные данные - от 0 до 9, в то время как желаемое поведение будет от 0 до 4.

Как этот объект набора данных узнает, что перечисление попало конец при использовании в a для l oop вот так? Также приветствуется любой другой метод достижения аналогичного эффекта.

Ответы [ 2 ]

2 голосов
/ 06 августа 2020

Вы создаете собственный загрузчик данных с классом Dataset, а перечисляете его с помощью l oop. Это не так, как это работает. Для перечисления вам необходимо передать класс Dataset в DataLoader. ваш код будет отлично работать, например,

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in range(len(ds)): #if you do dont want to use DataLoader, then dont use enumerate
    print(i, ds[i])
#output 
tensor([-0.2351,  1.3037])
tensor([ 0.4032, -0.2739])
tensor([-0.5687, -0.7300])
tensor([0.5418, 0.8572])
tensor([ 1.9973, -0.2939])

dl = DataLoader(ds, batch_size=1) # pass the ds object to DataLoader 

for i, x in enumerate(dl): # now you can use enumarate
    print(i, x)
#output
tensor([-0.2351,  1.3037])
tensor([ 0.4032, -0.2739])
tensor([-0.5687, -0.7300])
tensor([0.5418, 0.8572])
tensor([ 1.9973, -0.2939])

Более подробную информацию можно прочитать в этом официальном руководстве по pytorch .

1 голос
/ 06 августа 2020

Вы можете использовать torch.utils.data.Subset, чтобы получить подмножество ваших данных

top_five = torch.utils.data.Subset(ds, indices=range(5))  # Get first five items
for i, x in enumerate(top_five):
    print(i)
0
1
2
3
4

enumerate в l oop вернет элемент, пока не получит исключение StopIteration.

len(ds)         # Returned modified length
5

# `enumerate` will call `next` method on iterable each time in loop.
#  and When no more data available a StopIteration exception is raised instead.
iter_ds = iter(ds)
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))

print(next(iter_ds))  #11th time StopIteration exception raised as no item left to iterate in iterable

Вывод:

tensor([-1.5952, -0.0826])
tensor([-2.2254,  0.2461])
tensor([-0.8268,  0.1956])
tensor([ 0.3157, -0.3403])
tensor([0.8971, 1.1255])
tensor([0.3922, 1.3184])
tensor([-0.4311, -0.8898])
tensor([ 0.1128, -0.5708])
tensor([-0.5403, -0.9036])
tensor([0.6550, 1.6777])

---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-99-7a9910e027c3> in <module>
     10 print(next(iter_ds))
     11 
---> 12 print(next(iter_ds))  #11th time StopIteration exception raised as no item left to iterate

StopIteration: 
...