Как вычислить средневзвешенное значение тензора A вдоль оси с весами, заданными тензором B в тензорном потоке? - PullRequest
1 голос
/ 28 марта 2019

Я пытаюсь применить средневзвешенную схему к выходу RNN.
Выход RNN представлен тензором A, имеющим размерность (a,b,c).
Я могу просто взять tf.reduce_mean(A,axis=1), чтобы получить тензор C, имеющий размерность (a,c).

Однако я хочу сделать «средневзвешенное значение» тензора A вдоль axis = 1.
Веса указаны в матрице B, имеющей размерность (d,b).

Для d = 1 я могу сделать tf.tensordot(A,B,[1,1]), чтобы получить результат измерения (a,c).
Теперь для d=a я не могу вычислить средневзвешенное значение.

Может кто-нибудь предложить решение?

Ответы [ 2 ]

1 голос
/ 30 марта 2019

Поскольку B уже нормализовано, ответ

tf.reduce_sum(A * B[:, :, None], axis=1)

Индексирование с помощью None добавляет новое измерение, поведение, унаследованное от numpy. B[:,:, None] добавляет последнее измерение, чтобы результат имел форму (a, b, 1). Вы можете добиться того же с помощью tf.expand_dims, чье имя может иметь для вас больше смысла.

A имеет форму (a, b, c), тогда как B[:, :, None] имеет форму (a, b, 1). Когда они умножены, расширенный B будет также иметь форму (a, b, c), а последнее измерение будет c копий того же значения. Это называется вещание .

Из-за того, как вещание работает, тот же ответ также работает, если B имеет форму (1, b).

1 голос
/ 28 марта 2019

Я не совсем понимаю, почему B должен иметь размеры (d,b). Если B содержит веса, чтобы сделать средневзвешенное значение A только для одного измерения, B должен быть только вектором (b,), а не матрицей.

Если B - вектор, вы можете сделать:

C = tf.tensordot(A,B,[1,0]) для получения вектора C формы * (1013 *), который содержит средневзвешенное значение A по axis=1 с использованием весов, указанных в B.

Обновление:

Вы можете сделать что-то вроде:

A = A*B[:,:,None] 

, который выполняет поэлементное умножение A и B, где B хранит веса, присвоенные каждому элементу в A. Тогда:

C = tf.reduce_mean(A,axis=1)

получит средневзвешенное значение, поскольку каждый элемент в A умножен на его вес.

...