Я могу использовать tf.matmul(A, B)
для пакетного умножения матриц, когда:
A.shape == (..., a, b)
и
B.shape == (..., b, c)
, где ...
одинаковы.
Но я хочу дополнительное вещание:
Я ожидаю, что результатом будет a x b
пакетов умножения матриц между (2, d)
и (d, c)
.
Как это сделать?
Тестовый код:
import tensorflow as tf
import numpy as np
a = 3
b = 4
c = 5
d = 6
x_shape = (a, b, 2, d)
y_shape = (a, d, c)
z_shape = (a, b, 2, c)
x = np.random.uniform(0, 1, x_shape)
y = np.random.uniform(0, 1, y_shape)
z = np.empty(z_shape)
with tf.Session() as sess:
for i in range(b):
x_now = x[:, i, :, :]
z[:, i, :, :] = sess.run(
tf.matmul(x_now, y)
)
print(z)