Вы можете сделать это с помощью tf.tensor_scatter_nd_update
следующим образом:
import tensorflow as tf
tensor_a = ... # Some 3D tensor
idx_to_replace = 0
new_value = 1
s = tf.shape(tensor_a)
i1, i2 = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), indexing='ij')
i3 = idx_to_replace * tf.ones_like(i1)
idx = tf.stack([i1, i2, i3], axis=-1)
updates = new_value * tf.ones_like(i1, dtype=tensor_a.dtype)
result = tf.tensor_scatter_nd_update(tensor_a, idx, updates)
Хотя это не работает с отрицательными индексами, вам нужно сделать его положительным, например с помощью:
idx_to_replace = tf.cond(tf.less(idx_to_replace, 0),
lambda: idx_to_replace + s[-1],
lambda: idx_to_replace)
Однако, чтобы заменить первый индекс последнего измерения на единицы, вам может быть проще и быстрее сделать что-то вроде этого:
result = tf.concat([tf.ones_like(tensor_a[..., :1]), tensor_a[..., 1:]], axis=-1)
Аналогично для последнего измерения :
result = tf.concat([tensor_a[..., :-1], tf.ones_like(tensor_a[..., -1:])], axis=-1)