Я хотел бы реализовать word2ve c в pytorch, а метод forward имеет следующую сигнатуру:
def forward(self, input_index_batch, output_indices_batch):
#input_index_batch - Tensor of ints, shape: (batch_size, )
#output_indices_batch - Tensor if ints, shape: (batch_size, num_negative_samples+1)
В подходе word2ve c вычисляется скалярное произведение между входными данными в U - матрица вложения, и каждый из выходов num_negative_samples + 1, которые соответствуют контекстным словам, причем первый - это положительное контекстное слово, а все остальные - отрицательные. Я решил проблему следующим образом:
U_batch = self.input.weight.T[input_index_batch]
for i, output_index in enumerate(output_indices_batch.T):
V_batch = self.output.weight[output_index]
dot_prod = torch.einsum('bi,bj->b', U_batch, V_batch)
if i == 0:
predictions[:, i] = torch.sigmoid(dot_prod)
else:
predictions[:, i] = 1 - torch.sigmoid(dot_prod)
return predictions
Это работает правильно, но довольно неэффективно и медленно. Каким будет более надежный способ реализовать это?
В методе инициализации self.input и self.output определяются как (10 - word2ve c dim):
self.input = nn.Linear(num_tokens, 10, bias=False)
self.output = nn.Linear(10, num_tokens, bias=False)