Я использую тензор потока tf.gather
, чтобы получить элементы из многомерного массива, подобного этому:
import tensorflow as tf
indices = tf.constant([0, 1, 1])
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
result = tf.gather(x, indices, axis=1)
with tf.Session() as sess:
selection = sess.run(result)
print(selection)
, что приводит к:
[[1 2 2]
[4 5 5]
[7 8 8]]
, хотя я хочу:
[1
5
8]
как я могу использовать tf.gather
, чтобы применить отдельные индексы к указанной оси?(Тот же результат, что и в обходном пути, указанном в этом ответе: https://stackoverflow.com/a/41845855/9763766)