Первую операцию, которую вы хотите выполнить, можно просто выполнить следующим образом:
import tensorflow as tf
tf.random.set_seed(0)
params = tf.random.uniform([4, 64, 2])
idx = tf.random.uniform([4, 64], 0, 2, dtype=tf.int32)
out = tf.gather_nd(params, tf.expand_dims(idx, -1), batch_dims=2)
print(out.shape)
# (4, 64)
Для второй вам нужно построить полный многомерный индекс:
import tensorflow as tf
tf.random.set_seed(0)
params = tf.random.uniform([4, 64, 2])
idx = tf.random.uniform([4, 64], 0, 2, dtype=tf.int32)
update = tf.random.uniform([4, 64])
s = tf.shape(idx, out_type=idx.dtype)
ii, jj = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), indexing='ij')
idx_comp = tf.stack([ii, jj, idx], axis=-1)
out = tf.tensor_scatter_nd_update(params, idx_comp, update)
Хотя в В конкретном случае, когда ваше последнее измерение имеет два элемента, вы также можете использовать эту эквивалентную операцию:
update_t = tf.tile(tf.expand_dims(update, axis=-1), [1, 1, 2])
idx_t = tf.stack([idx, 1 - idx], axis=-1)
out = tf.where(tf.dtypes.cast(idx_t, tf.bool), params, update_t)