Вы можете сделать следующее. По сути, сначала вы сглаживаете все измерения, кроме последнего из y
, и создаете индекс для сглаживания y
. Вы выполняете индексацию, а затем изменяете форму до правильной формы.
y = tf.constant(np.random.normal(size=(5,10,20,3)), dtype='float32')
y_index = tf.constant(np.random.randint(3, size=(5,10,20)), dtype='int32')
# Creating an index like [(0,y_index[0]), (1, y_index[1]), ...]
inds = tf.stack([tf.range(5*10*20),tf.reshape(y_index,[-1])],axis=1)
y_slice = tf.reshape(tf.gather_nd(tf.reshape(y,[-1,3]),inds),[5,10,20])