Я использую пользовательский набор данных PyTorch со следующим:
class ImageDataset(Dataset):
def __init__(self, input_dir, input_num, input_format, transform=None):
self.input_num = input_num
# etc
def __len__ (self):
return self.input_num
def __getitem__(self,idx):
targetnum = idx % self.input_num
# etc
Однако, когда я выполняю итерацию по этому набору данных, итерация возвращается к началу набора данных, а не заканчивается в конце набора данных.,Это фактически становится бесконечным циклом в итераторе, с оператором печати эпохи, никогда не выполняющимся для последующих эпох.
train_dataset=ImageDataset(input_dir = 'path/to/directory',
input_num = 300, input_format = "mask") # Size 300
num_epochs = 10
for epoch in range(num_epochs):
print("EPOCH " + str(epoch+1) + "\n")
num = 0
for data in train_dataset:
print(num, end=" ")
num += 1
# etc
Вывод на печать (... для значений между ними):
EPOCH 1
0 1 2 3 4 5 6 7 ... 298 299 300 301 302 303 304 305 ... 597 598 599 600 601 602 603 604 ...
Почему базовая итерация по набору данных продолжается после определенного __len__
набора данных, и как я могу гарантировать, что итерация по набору данных прекращается после достижения длины набора данных при использовании этого метода (или выполняется итерация по диапазону вручнуюдлины набора данных единственное решение)?
Спасибо.