Не уверен, что я понимаю вашу заботу о (ни одном) измерении.
Если я правильно понимаю, косинусное сходство между двумя матрицами одинаковой формы X
и Y
([batch, target_dim]
) - это просто матричное умножение X * Y^T
с некоторой нормализацией L2. Примечание X
будет вашим out1
, а Y
будет вашим out2
.
def Cosine_similarity(x, y, A):
"""Pair-wise Cosine similarity.
First `x` and `y` are transformed by A.
`X = xA^T` with shape [batch, target_dim],
`Y = yA^T` with shape [batch, target_dim].
Args:
x: shaped [batch, feature_dim].
y: shaped [batch, feature_dim].
A: shaped [targte_dim, feature_dim]. Transformation matrix to project
from `feature_dim` to `target_dim`.
Returns:
A cosine similarity matrix shaped [batch, batch]. The entry
at (i, j) is the cosine similarity value between vector `X[i, :]` and
`Y[j, :]` where `X`, `Y` are the transformed `x` and y` by `A`
respectively. In the other word, entry at (i, j) is the pair-wise
cosine similarity value between the i-th example of `x` and the j-th
example of `y`.
"""
x = tf.matmul(x, A, transpose_b=True)
y = tf.matmul(y, A, transpose_b=True)
x_norm = tf.nn.l2_normalize(x, axis=-1)
y_norm = tf.nn.l2_normalize(y, axis=-1)
y_norm_trans = tf.transpose(y_norm, [1, 0])
sim = tf.matmul(x_norm, y_norm_trans)
return sim
import numpy as np
feature_dim = 8
target_dim = 4
batch_size = 2
x = tf.placeholder(tf.float32, shape=(None, dim))
y = tf.placeholder(tf.float32, shape=(None, dim))
A = tf.placeholder(tf.float32, shape=(target_dim, feature_dim))
sim = Cosine_similarity(x, y, A)
with tf.Session() as sess:
x, y, sim = sess.run([x, y, sim], feed_dict={
x: np.ones((batch_size, feature_dim)),
y: np.random.rand(batch_size, feature_dim),
A: np.random.rand(target_dim, feature_dim)})
print 'x=\n', x
print 'y=\n', y
print 'sim=\n', sim
Результат:
x=
[[ 1. 1. 1. 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1. 1. 1. 1.]]
y=
[[ 0.01471654 0.76577073 0.97747731 0.06429122 0.91344446 0.47987637
0.09899797 0.773938 ]
[ 0.8555786 0.43403915 0.92445409 0.03393625 0.30154493 0.60895061
0.1233703 0.58597666]]
sim=
[[ 0.95917791 0.98181278]
[ 0.95917791 0.98181278]]