Это должно помочь вам (переходя к более мелким фигурам для ускорения вычислений):
X = np.ones(((5, 2, 3, 7)))
W = np.ones((X.shape[3], 10))
X_reshaped = tf.reshape(X, [-1, X.shape[3]])
# Shape: (30, 7)
y = tf.matmul(X_reshaped, W)
# Shape: (30, 10)
y_reshaped = tf.reshape(y, [-1, X.shape[1], X.shape[2], 10])
# Shape: (5, 2, 3, 10)