Предположим, у меня есть партия, состоящая из двух тензоров, а тензоры в патче имеют размер 3.
data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]
Теперь я хочу извлечь из каждого тензора в патче одну элементную базу на индекс:
index = [0, 2]
Поэтому выходной результат должен быть
out = [0.3, -0.7] # Get index 0 from the first tensor in the batch and index 2 from the second tensor in the batch.
Конечно, это должно быть распространено на большие размеры партии. Размер index
равен размеру пакета.
Я пытался применить tf.gather
и tf.gather_nd
, но не получил желаемых результатов.
Например, код ниже выведите 0.7
и , а не желаемый результат, указанный выше:
data = [[0.3, 0.5, 0.7], [-0.3, -0.5, 0.7]]
index = [0, 2]
out = tf.gather_nd(data, index)
print(out.numpy())