Как вычислить косинусоподобие в pytorch для всех строк в матрице относительно всех строк в другой матрице - PullRequest
0 голосов
/ 18 мая 2018

В pytorch, учитывая, что у меня есть 2 матрицы, как бы я вычислил косинусное сходство всех строк в каждой со всеми строками в другой.

Например,

При заданном входном значении =

matrix_1 = [a b] 
           [c d] 
matrix_2 = [e f] 
           [g h]

Я хотел бы, чтобы вывод был

output =

 [cosine_sim([a b] [e f])  cosine_sim([a b] [g h])]
 [cosine_sim([c d] [e f])  cosine_sim([c d] [g h])] 

В настоящее время я использую torch.nn.functional.cosine_simility (matrix_1, matrix_2), которыйвозвращает косинус строки только с соответствующей строкой в ​​другой матрице.

В моем примере у меня есть только 2 строки, но я бы хотел решение, которое работает для многих строк.Я бы даже хотел обработать случай, когда количество строк в каждой матрице отличается.

Я понимаю, что могу использовать расширение, однако я хочу сделать это без использования такого большого объема памяти.

1 Ответ

0 голосов
/ 19 мая 2018

Путем ручного вычисления сходства и игры с умножением матриц + транспонированием:

import torch
from scipy import spatial
import numpy as np

a = torch.randn(2, 2)
b = torch.randn(3, 2) # different row number, for the fun

# Given that cos_sim(u, v) = dot(u, v) / (norm(u) * norm(v))
#                          = dot(u / norm(u), v / norm(v))
# We fist normalize the rows, before computing their dot products via transposition:
a_norm = a / a.norm(dim=1)[:, None]
b_norm = b / b.norm(dim=1)[:, None]
res = torch.mm(a_norm, b_norm.transpose(0,1))
print(res)
#  0.9978 -0.9986 -0.9985
# -0.8629  0.9172  0.9172

# -------
# Let's verify with numpy/scipy if our computations are correct:
a_n = a.numpy()
b_n = b.numpy()
res_n = np.zeros((2, 3))
for i in range(2):
    for j in range(3):
        # cos_sim(u, v) = 1 - cos_dist(u, v)
        res_n[i, j] = 1 - spatial.distance.cosine(a_n[i], b_n[j])
print(res_n)
# [[ 0.9978022  -0.99855876 -0.99854881]
#  [-0.86285472  0.91716063  0.9172349 ]]
...