Как сделать предложение сходства с X LNet? - PullRequest
1 голос
/ 22 марта 2020

Я хочу выполнить задачу сходства предложений и попробовал следующее:

from transformers import XLNetTokenizer, XLNetModel
import torch
import scipy
import torch.nn as nn
import torch.nn.functional as F

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetModel.from_pretrained('xlnet-large-cased')

input_ids = torch.tensor(tokenizer.encode("Hello, my animal is cute", add_special_tokens=False)).unsqueeze(0)
outputs = model(input_ids)
last_hidden_states = outputs[0]

input_ids = torch.tensor(tokenizer.encode("I like your cat", add_special_tokens=False)).unsqueeze(0) 

outputs1 = model(input_ids)
last_hidden_states1 = outputs1[0]

cos = nn.CosineSimilarity(dim=1, eps=1e-6)
output = cos(last_hidden_states, last_hidden_states1)

Однако я получаю следующую ошибку:

RuntimeError: The size of tensor a (7) must match the size of tensor b (4) at non-singleton dimension 1

Может кто-нибудь сказать мне, что я делать неправильно? Есть ли лучший способ сделать это?

1 Ответ

0 голосов
/ 23 марта 2020

Есть несколько вещей, которые вы делаете неправильно.

  1. add_special_tokens должно быть установлено на True. Модель была обучена токену <sep> для разделения предложений и токену <cls> для классификации предложений. Неиспользование приводит к странному поведению из-за несоответствия данных теста поезда.

  2. outputs[0] дает вам первый элемент односоставного кортежа Python. Все модели из пакета Transformer возвращают кортежи, поэтому это кортеж из одного члена. Он содержит один вектор на каждый входной токен, включая специальные.

  3. В отличие от BERT, чей токен [CLS] является первым, здесь токен <cls> является самым последним (см. Трансформаторная документация ). Если вы хотите сравнить векторы классификации, вы должны взять последний вектор из последовательности, то есть outputs[0][:, -1].

В качестве альтернативы, вы можете сравнить среднее (среднее значение) встраивание, а не <cls> вложение токена. В этом случае вы можете просто сделать output[0].mean(1).

...