Аргументы ключевых слов в функции вызова BERT - PullRequest
2 голосов
/ 25 марта 2020

В библиотеке BERT HuggingFace TensorFlow 2.0 документация гласит, что:

Модели TF 2.0 принимают в качестве входных данных два формата:

  • имея все входные данные в качестве аргументов ключевого слова (например, модели PyTorch), или

  • имея все входные данные в виде списка, кортежа или диктата в первых позиционных аргументах.

Я пытаюсь использовать первый из этих двух для вызова созданной мной модели BERT:

from transformers import BertTokenizer, TFBertModel
import tensorflow as tf

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = TFBertModel.from_pretrained('bert-base-uncased')

text = ['This is a sentence', 
        'The sky is blue and the grass is green', 
        'More words are here']
labels = [0, 1, 0]
tokenized_text = bert_tokenizer.batch_encode_plus(batch_text_or_text_pairs=text,
                                                  pad_to_max_length=True,
                                                  return_tensors='tf')
dataset = tf.data.Dataset.from_tensor_slices((tokenized_text['input_ids'],
                                              tokenized_text['attention_mask'],
                                              tokenized_text['token_type_ids'],
                                              tf.constant(labels))).batch(3)
sample = next(iter(dataset))

result1 = bert_model(inputs=(sample[0], sample[1], sample[2]))  # works fine
result2 = bert_model(inputs={'input_ids': sample[0], 
                             'attention_mask': sample[1], 
                             'token_type_ids': sample[2]})  # also fine
result3 = bert_model(input_ids=sample[0], 
                     attention_mask=sample[1], 
                     token_type_ids=sample[2])  # raises an error

Но когда я выполняю последнюю строку, я получаю сообщение об ошибке:

TypeError: __call__() missing 1 required positional argument: 'inputs'

Может кто-нибудь объяснить, как правильно использовать стиль ввода аргументов ключевого слова?

1 Ответ

1 голос
/ 25 марта 2020

Кажется, что внутренне они интерпретируют inputs как input_ids, если вы не указали в качестве первого аргумента более одного тензора. Вы можете видеть это в TFBertModel, а затем ищете TFBertMainLayer call функцию.

Для меня я получаю точно такие же результаты, как result1 и result2 если я сделаю следующее:

result3 = bert_model(inputs=sample[0], 
                     attention_mask=sample[1], 
                     token_type_ids=sample[2])

В качестве альтернативы, вы также можете просто сбросить inputs=, тоже работает.

...