Как создать torchtext.data.TabularDataset прямо из списка или продиктовать - PullRequest
0 голосов
/ 29 октября 2018

torchtext.data.TabularDataset можно создать из файла TSV / JSON / CSV, а затем использовать его для построения словаря из Glove, FastText или любых других вложений. Но мое требование - создать torchtext.data.TabularDataset напрямую, либо из list, либо из dict.

Текущая реализация кода путем чтения файлов TSV

self.RAW = data.RawField()
self.TEXT = data.Field(batch_first=True)
self.LABEL = data.Field(sequential=False, unk_token=None)


self.train, self.dev, self.test = data.TabularDataset.splits(
    path='.data/quora',
    train='train.tsv',
    validation='dev.tsv',
    test='test.tsv',
    format='tsv',
    fields=[('label', self.LABEL),
            ('q1', self.TEXT),
            ('q2', self.TEXT),
            ('id', self.RAW)])


self.TEXT.build_vocab(self.train, self.dev, self.test, vectors=GloVe(name='840B', dim=300))
self.LABEL.build_vocab(self.train)


sort_key = lambda x: data.interleave_keys(len(x.q1), len(x.q2))


self.train_iter, self.dev_iter, self.test_iter = \
    data.BucketIterator.splits((self.train, self.dev, self.test),
                               batch_sizes=[args.batch_size] * 3,
                               device=args.gpu,
                               sort_key=sort_key)

Это текущий рабочий код для чтения данных из файла. Поэтому, чтобы создать набор данных непосредственно из List / Dict, я попробовал встроенные функции, такие как Examples.fromDict или examples.fromList, но затем при переходе к последнему циклу for выдает ошибку, что AttributeError: 'BucketIterator' object has no attribute 'q1'

1 Ответ

0 голосов
/ 01 ноября 2018

Мне потребовалось написать собственный класс, унаследованный от класса Dataset и с небольшими изменениями в классе torchtext.data.TabularDataset.

class TabularDataset_From_List(data.Dataset):

    def __init__(self, input_list, format, fields, skip_header=False, **kwargs):
        make_example = {
            'json': Example.fromJSON, 'dict': Example.fromdict,
            'tsv': Example.fromTSV, 'csv': Example.fromCSV}[format.lower()]

        examples = [make_example(item, fields) for item in input_list]

        if make_example in (Example.fromdict, Example.fromJSON):
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)

        super(TabularDataset_From_List, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, path=None, root='.data', train=None, validation=None,
               test=None, **kwargs):
        if path is None:
            path = cls.download(root)
        train_data = None if train is None else cls(
            train, **kwargs)
        val_data = None if validation is None else cls(
            validation, **kwargs)
        test_data = None if test is None else cls(
            test, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)
...