тензорный эквивалент torch.gather - PullRequest
0 голосов
/ 01 сентября 2018

У меня есть тензор формы (16, 4096, 3). У меня есть другой тензор показателей формы (16, 32768, 3). Я пытаюсь собрать значения вдоль dim=1. Первоначально это было сделано в pytorch с использованием функции сбора , как показано ниже -

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

Обратите внимание, что размер вывода b такой же, как у idx. Однако, когда я применяю gather функцию тензорного потока, я получаю совершенно другой вывод. Было обнаружено, что выходное измерение не соответствует, как показано ниже -

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

Я тоже пытался использовать tf.gather_nd, но тщетно. Смотри ниже-

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)

Почему я получаю различные формы тензоров? Я хочу получить тензор той же формы, который рассчитан по pytorch.

Другими словами, я хочу знать тензор потока, эквивалентный torch.gather.

1 Ответ

0 голосов
/ 04 октября 2018

Для 2D-случая, есть способ сделать это:

# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)

Однако для случая ND этот метод может быть очень сложным

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...