Тензор потока tf.gather с параметром оси - PullRequest
0 голосов
/ 22 мая 2018

Я использую тензор потока tf.gather, чтобы получить элементы из многомерного массива, подобного этому:

import tensorflow as tf

indices = tf.constant([0, 1, 1])
x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])

result = tf.gather(x, indices, axis=1)

with tf.Session() as sess:
    selection = sess.run(result)
    print(selection)

, что приводит к:

[[1 2 2]
 [4 5 5]
 [7 8 8]]

, хотя я хочу:

[1
 5
 8]

как я могу использовать tf.gather, чтобы применить отдельные индексы к указанной оси?(Тот же результат, что и в обходном пути, указанном в этом ответе: https://stackoverflow.com/a/41845855/9763766)

1 Ответ

0 голосов
/ 22 мая 2018

Вам необходимо преобразовать indices в full indices, используя gather_nd.Может быть достигнуто путем:

result = tf.squeeze(tf.gather_nd(x,tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices[...,tf.newaxis]], axis=2)))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...