Я пытаюсь использовать предварительно обученный токенизатор Transformer-XL из HuggingFace при обучении моей пользовательской модели Transformer-XL на WikiText2, и у меня возникают проблемы с работой BPTTIterator из Torchtext. Ошибка возникает в самой последней строке next(iter(train_iter))
Ниже мой код:
# Import packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AdamW, WarmupLinearSchedule
from transformers import TransfoXLConfig, TransfoXLTokenizer, TransfoXLModel,TransfoXLLMHeadModel
import torchtext
import torchtext.data.utils
from torchtext.data import Field, BPTTIterator
import random
import time
# Hyperparameters for this experiment:
#   bptt       - sequence length of each backprop-through-time segment
#   batch_size - number of sequences per mini-batch
#   lr         - optimizer learning rate
bptt, batch_size, lr = 30, 64, 0.01
# Load the pretrained Transformer-XL (WikiText-103) tokenizer.
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103', do_lower_case=True)
# For HuggingFace <-> torchtext integration: register the special tokens
# and cache their vocabulary ids.  NOTE(review): the original script
# repeated this entire section twice verbatim; the duplicate block has
# been removed — running it once yields the exact same state.
tokenizer.mask_token = 'maskTok'
tokenizer.pad_token = '<pad>'
tokenizer.eos_token = '<eos>'
tokenizer.unk_token = '<unk>'
tokenizer.bos_token = '<sos>'
# Integer ids of the special tokens, used when building torchtext Fields.
pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
eos_index = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
unk_index = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)
mask_index = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
bos_index = tokenizer.convert_tokens_to_ids(tokenizer.bos_token)
# Define the torchtext Field that WikiText2.splits requires.
# NOTE(review): the original code passed an undefined name `TEXT`
# (a NameError on its own), and — the cause of the reported
# "TypeError: an integer is required (got type str)" — a default Field
# would hand raw string tokens to numericalize().  Setting
# use_vocab=False with tokenize=tokenizer.encode makes the pipeline
# emit integer ids from the pretrained HuggingFace vocabulary directly,
# so Field.numericalize can build the LongTensor without a torchtext
# vocab.  Special-token arguments must then be ids, not strings.
TEXT = Field(use_vocab=False,
             tokenize=tokenizer.encode,
             pad_token=pad_index,
             eos_token=eos_index)
# Load the WikiText-2 dataset and split it into train/val/test sets.
train_Wiki2, val_Wiki2, test_Wiki2 = torchtext.datasets.WikiText2.splits(TEXT)
# Total number of tokens in the pretrained tokenizer's vocabulary.
ntokens = tokenizer.vocab_size
# Transformer-XL configuration for the custom language model.
_config_kwargs = dict(
    vocab_size_or_config_json_file=ntokens,  # size of the pretrained vocab
    cutoffs=[20000, 40000, 200000],          # adaptive-softmax cutoffs
    d_model=1024,
    d_embed=1024,
    n_head=16,
    d_head=64,
    n_layer=5,
    dropout=0.1,
    attn_type=0,
    output_hidden_states=True,
    output_attentions=True,
)
transfoXLconfig = TransfoXLConfig(**_config_kwargs)
# Build the LM-head model and grow its embedding matrix so the special
# tokens added to the tokenizer have embedding rows.
model = TransfoXLLMHeadModel(config=transfoXLconfig)
model.resize_token_embeddings(len(tokenizer))
# Build bptt-length iterators over the train and test splits; splits()
# returns one iterator per dataset in the first positional argument.
train_iter, test_iter = BPTTIterator.splits(
(train_Wiki2, test_Wiki2),
batch_size = batch_size,
bptt_len= bptt,
shuffle = False,
repeat=False)
# Fetch one batch from each iterator.  This is where the reported error
# surfaced:
#   File ".../torchtext/data/field.py", line 359, in numericalize
#     var = torch.tensor(arr, dtype=self.dtype, device=device)
#   TypeError: an integer is required (got type str)
# NOTE(review): the TypeError means the Field delivered *string* tokens
# where integer ids were expected — the Field handed to
# WikiText2.splits must produce ids (e.g. use_vocab=False with an
# id-producing tokenize function, or a built torchtext vocab).
train = next(iter(train_iter))
test = next(iter(test_iter))
Как я могу исправить свою ошибку? Спасибо,