Трансформеры huggingface: стратегия усечения в encode_plus - PullRequest
0 голосов
/ 06 августа 2020

encode_plus в библиотеке transformers huggingface позволяет обрезать входную последовательность. Релевантны два параметра: truncation и max_length. Я передаю парную входную последовательность в encode_plus, и мне нужно усечь входную последовательность просто «отрезанным» способом, т. Е. Если вся последовательность, состоящая из обоих входов text и text_pair, длиннее max_length он должен быть просто усечен соответственно справа.

Кажется, что ни одна из стратегий усечения не позволяет этого сделать, вместо этого longest_first удаляет токены из самой длинной последовательности (которая может быть либо text, либо text_pair, но не просто справа или в конце последовательности, например, если текст длиннее, чем text_pair, кажется, что сначала удаляются токены из текста), only_first и only_second удаляют токены только из первого или второго (следовательно, также не просто с конца), а do_not_truncate вообще не усекает. Или я неправильно понял это, и на самом деле longest_first может быть тем, что я ищу?

1 Ответ

1 голос
/ 07 августа 2020

Нет longest_first не то же самое, что cut from the right. Когда вы устанавливаете стратегию усечения на longest_first, токенизатор будет сравнивать длину как text, так и text_pair каждый раз, когда требуется удалить токен, и удаляет токен из самого длинного. Например, это может означать, что сначала будет вырезано 3 токена из text_pair, а остальные токены, которые необходимо вырезать, будут поочередно вырезаны из text и text_pair. Пример:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

seq1 = 'This is a long uninteresting text'
seq2 = 'What could be a second sequence to the uninteresting text'

print(len(tokenizer.tokenize(seq1)))
print(len(tokenizer.tokenize(seq2)))

print(tokenizer(seq1, seq2))

print(tokenizer(seq1, seq2, truncation= True, max_length = 15))
print(tokenizer.decode(tokenizer(seq1, seq2, truncation= True, max_length = 15)['input_ids']))

Вывод:

9
13
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 1037, 2117, 5537, 2000, 1996, 4895, 18447, 18702, 3436, 3793, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 102, 2054, 2071, 2022, 1037, 2117, 5537, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[CLS] this is a long unint [SEP] what could be a second sequence [SEP]

Насколько я могу судить по вашему вопросу, вы на самом деле ищете only_second, потому что он срезается справа (это text_pair):

print(tokenizer(seq1, seq2, truncation= 'only_second', max_length = 15))

Вывод:

{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Он генерирует исключение, когда вы пытаетесь ввести text, длиннее указанного max_length. На мой взгляд, это правильно, потому что в данном случае это уже не ввод пары последовательностей.

На случай, если only_second не соответствует вашим требованиям, вы можете просто создать свою собственную стратегию усечения. Например, only_second вручную:


tok_seq1 = tokenizer.tokenize(seq1)
tok_seq2 = tokenizer.tokenize(seq2)

maxLengthSeq2 =  myMax_len - len(tok_seq1) - 3 #number of special tokens for bert sequence pair
if len(tok_seq2) >  maxLengthSeq2:
    tok_seq2 = tok_seq2[:maxLengthSeq2]

input_ids = [tokenizer.cls_token_id] 
input_ids += tokenizer.convert_tokens_to_ids(tok_seq1)
input_ids += [tokenizer.sep_token_id]

token_type_ids = [0]*len(input_ids)

input_ids += tokenizer.convert_tokens_to_ids(tok_seq2)
input_ids += [tokenizer.sep_token_id]
token_type_ids += [1]*(len(tok_seq2)+1) 


attention_mask = [1]*len(input_ids)
print(input_ids)
print(token_type_ids)
print(attention_mask)

Вывод:

[101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
...