Я могу представить две возможности, в зависимости от желаемого результата:
1) использовать аргумент batch_shape
, tf.reshape
и, возможно, tf.transpose
ident = tf.eye(5, batch_shape=(3, 4) # shape = (3, 4, 5, 5)
# switch axis 2 with axis 0
ident = tf.transpose(ident, (2, 1, 0, 3)) # shape = (5, 4, 3, 5)
2) используйте tf.expand_dims
или tf.reshape
в сочетании с tf.tile
:
ident = tf.eye(5) # shape = (5, 5)
ident = tf.reshape(ident, (5, 1, 1, 5)) # shape = (5, 1, 1, 5)
# or: ident = tf.expand_dims(tf.expand_dims(ident, axis=1), axis=1)
ident = tf.tile(ident, (1, 4, 3, 1)) # shape = (5, 4, 3, 5)