Загрузка огромного пользовательского набора данных с помощью IterableDataset - PullRequest
0 голосов
/ 07 августа 2020

У меня есть огромный набор данных с функциями (input_id, input_mask, segment_id, label_id), сохраненными партиями по 64 в файле pickle. Я читаю этот файл, создаю TensorDataset и передаю загрузчику данных для обучения. Поскольку файл функций слишком велик для создания полного набора TensorDataset, я хочу преобразовать TensorDataset в IterableDataset, чтобы по одному пакету образцов можно было извлекать из файла функций за раз и передавать загрузчику данных. Но во время обучения я получаю следующую ошибку: TypeError: iter() returned non-iterator of type 'TensorDataset'

Ниже приведен класс пользовательского набора данных, который я написал:

class MyDataset(IterableDataset):

    def __init__(self,args):
        self.args=args
       
    def get_features(self,filename):
        with open(filename, "rb") as f:
            while True:
                try:
                    yield pickle.load(f)
                except EOFError:
                    break  
                    
    def process(self,args):
        if args.cached_features_file:
            cached_features_file = args.cached_features_file

        if os.path.exists(cached_features_file):
            features=self.get_features(cached_features_file)

        feat = next (features)
        li=list(feat)
        all_input_ids=torch.tensor([f.input_ids for f in li ], dtype=torch.long)
        all_input_mask= torch.tensor([f.input_mask for f in li ], dtype=torch.long)
        all_segment_ids= torch.tensor([f.segment_ids for f in li], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in li ], dtype=torch.long)
        
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return dataset
      
    def __iter__(self):
        dataset=self.process(self.args)       
        return dataset

И я использую его так:

train_dataset=MyDataset(args)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)

Я понимаю, что TensorDataset - это стиль карты, требующий индекса, в то время как IterableDataset - это итеративный стиль, что является причиной ошибки. Даже если я верну список / кортеж тензоров функций вместо TensorDataset, я получаю аналогичную ошибку. Может кто-нибудь сказать мне, как правильно загрузить пакетный набор данных с помощью IterableDataset?

1 Ответ

0 голосов
/ 10 августа 2020

Я решил проблему, сохранив набор данных другим способом. Я сохранил функции в виде объектов словаря, постепенно добавленных в файл pickle, и просто прочитал их по одному и передал загрузчику данных для обработки. Пакетирование выполняется автоматически загрузчиком данных. Вот как теперь выглядит пользовательский класс:

class MyDataset(IterableDataset):

    def __init__(self,filename):
     
        self.filename=filename
        super().__init__()
                    
    def process(self,filename):
        with open(filename, "rb") as f:
            while True:
                try:
                    yield pickle.load(f)
                except EOFError:
                    break

    def __iter__(self):
        dataset=self.process(self.filename)          
        return dataset
...