a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))
# [[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]]
b=tf.constant([-1,2])
aa=tf.pad(a,[[0,0],[1,0],[0,0]])
bb=b+1
index=tf.stack([tf.range(tf.size(b)),bb],axis=-1)
res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)
#[[[ 0, 0, 0, 0]],
#[[20, 21, 22, 23]]]
Когда индекс равен -1, нам нужны нули, такие как тензор. Таким образом, мы можем сначала добавить оригинальный тензор вдоль второй оси. Затем увеличьте индексы на 1. После этого, используя tf.gather_nd
, вы получите ответ.