Вы можете достичь вышеуказанного в матричной форме, без каких-либо циклов:
T --> [C, F, M]
T_1 --> transpose T to --> [C, F, M]
T_2 --> transpose T to --> [C, M, F]
d --> matmul (T_1, T_2) --> [C, M, M] --> transpose --> [M, M, C]
out --> multiply (d, N) : d -> [1, M, M, C], N -> [batch, 1, 1, C]
--> [batch, M, M, C] --> reduce_sum (axis=2) --> [batch, M, M]
--> add I
Рабочий код (соответствует вашему коду для batch=1
):
N_1 = tf.placeholder(tf.float32, [None, C])
reshape_T = tf.reshape(T, [C, F, M])
# reshape to do a batch matrix multiplication (C, F, M) and (C, M, F)
T_1 = tf.transpose(reshape_T, [0, 2, 1])
T_2 = tf.transpose(reshape_T, [0, 1, 2])
d = tf.transpose(tf.matmul(T_1,T_2), [2,1,0])
out = tf.reduce_sum(d[None,...]* tf.reshape(N_1, [-1, 1, 1, C]), axis=3) + I
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(out, {N_1: inp}))