Как работает аргумент masked_lm_labels в BertForMaskedLM? - PullRequest
1 голос
/ 28 апреля 2020
from transformers import BertTokenizer, BertForMaskedLM
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)

loss, prediction_scores = outputs[:2] 

Этот код со страницы трансформеров обнимающего лица. https://huggingface.co/transformers/model_doc/bert.html#bertformaskedlm

Я не могу понять аргумент masked_lm_labels=input_ids в model. Как это работает? Означает ли это, что он будет автоматически маскировать часть текста при передаче input_ids?

1 Ответ

0 голосов
/ 28 апреля 2020

Первый аргумент - это маскированный вход, аргумент masked_lm_labels - требуемый.

input_ids должен быть замаскирован. В общем, это зависит от вас, как вы делаете маскировку. В исходном BERT они выбирают 15% токенов и следующие с ними, либо

  • Использование [MASK] токенов; или
  • Используйте случайный токен; или
  • Сохранить исходный токен без изменений.

Это изменяет входные данные, поэтому вам нужно указать вашей модели, какой исходный немаскированный входной сигнал является аргументом masked_lm_labels. Также обратите внимание, что вы не хотите вычислять потери только для токенов, которые были фактически выбраны для маскировки. Остальные токены должны быть заменены индексом -100.

Для получения более подробной информации, см. Документацию .

...