Используя `transpose` и `reshape`, можно получить тот же результат:
a : [batch, 1152, 8] --> reshape --> [batch, 1, 1, 1152, 8]
b : [16,8,1152,10] --> transpose --> [16, 10, 1152, 8]
--> expand_dims --> [1, 16, 10, 1152, 8]
multiply(a, b) --> [batch, 16, 10, 1152, 8]
reduce_sum по оси 4 --> [batch, 16, 10, 1152]
--> transpose [0, 1, 3, 2] --> [batch, 16, 1152, 10]
Код:
# Demo: reproduce tf.einsum('ijkl,bkj->bikl', b, a) using only
# transpose / reshape / broadcast-multiply / reduce_sum, and check that
# both graphs give the same result.
# Written against the TensorFlow 1.x graph API (tf.placeholder / tf.Session);
# NOTE(review): under TF 2.x use the tf.compat.v1 equivalents with eager
# execution disabled.
import numpy as np
import numpy.testing as npt
import tensorflow as tf

# Inputs: a batch of 5 samples shaped (1152, 8) and a fixed tensor shaped
# (16, 8, 1152, 10).
x = np.random.normal(size=(5, 1152, 8))
y = np.random.normal(size=(16, 8, 1152, 10))

a = tf.placeholder(tf.float32, shape=(None, 1152, 8))
b = tf.constant(y, tf.float32)

# b: [16, 8, 1152, 10] -> transpose -> [16, 10, 1152, 8]
#    -> expand_dims -> [1, 16, 10, 1152, 8]
b_5d = tf.expand_dims(tf.transpose(b, [0, 3, 2, 1]), 0)
# a: [batch, 1152, 8] -> reshape -> [batch, 1, 1, 1152, 8]
# (the batch size is dynamic, so the trailing dims are read from tf.shape)
a_5d = tf.reshape(a, [-1, 1, 1, tf.shape(a)[1], tf.shape(a)[2]])

# Broadcast multiply -> [batch, 16, 10, 1152, 8]; summing the last axis
# contracts the size-8 dimension -> [batch, 16, 10, 1152].
out = tf.reduce_sum(b_5d * a_5d, axis=4)
# Swap the last two axes to match the einsum output layout
# -> [batch, 16, 1152, 10].
out = tf.transpose(out, [0, 1, 3, 2])

# Reference: the same contraction written directly as an einsum.
out_ein = tf.einsum('ijkl,bkj->bikl', b, a)

with tf.Session() as sess:
    # The graph has no tf.Variable, so no initializer needs to run; fetch
    # both results in a single call with the same feed.
    o, e = sess.run([out, out_ein], {a: x})
    # The two graphs accumulate the products in a different order, so the
    # float32 results match only up to rounding error.
    npt.assert_almost_equal(o, e, decimal=5)