У меня есть огромный набор данных с функциями (input_id, input_mask, segment_id, label_id), сохраненными партиями по 64 в файле pickle. Я читаю этот файл, создаю TensorDataset и передаю загрузчику данных для обучения. Поскольку файл функций слишком велик для создания полного набора TensorDataset, я хочу преобразовать TensorDataset в IterableDataset, чтобы по одному пакету образцов можно было извлекать из файла функций за раз и передавать загрузчику данных. Но во время обучения я получаю следующую ошибку: TypeError: iter() returned non-iterator of type 'TensorDataset'
Ниже приведен класс пользовательского набора данных, который я написал:
class MyDataset(IterableDataset):
def __init__(self,args):
self.args=args
def get_features(self,filename):
with open(filename, "rb") as f:
while True:
try:
yield pickle.load(f)
except EOFError:
break
def process(self,args):
if args.cached_features_file:
cached_features_file = args.cached_features_file
if os.path.exists(cached_features_file):
features=self.get_features(cached_features_file)
feat = next (features)
li=list(feat)
all_input_ids=torch.tensor([f.input_ids for f in li ], dtype=torch.long)
all_input_mask= torch.tensor([f.input_mask for f in li ], dtype=torch.long)
all_segment_ids= torch.tensor([f.segment_ids for f in li], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in li ], dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
return dataset
def __iter__(self):
dataset=self.process(self.args)
return dataset
И я использую его так:
train_dataset=MyDataset(args)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
Я понимаю, что TensorDataset - это стиль карты, требующий индекса, в то время как IterableDataset - это итеративный стиль, что является причиной ошибки. Даже если я верну список / кортеж тензоров функций вместо TensorDataset, я получаю аналогичную ошибку. Может кто-нибудь сказать мне, как правильно загрузить пакетный набор данных с помощью IterableDataset?