Различное поведение функции collect (), как в тензорном потоке и pytorch - PullRequest
0 голосов
/ 31 августа 2018

У меня есть тензор формы (16, 4096, 3). У меня есть другой тензор показателей формы (16, 32768, 3). Я пытаюсь собрать значения вдоль dim=1. Первоначально это было сделано в pytorch с использованием функции gather, как показано ниже -

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

Обратите внимание, что размер вывода b совпадает с размером idx. Тем не менее, когда я применяю gather функцию тензорного потока, я получаю совершенно другой вывод. Было обнаружено, что выходное измерение не соответствует, как показано ниже -

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

Я тоже пытался использовать tf.gather_nd, но тщетно. Смотри ниже-

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)

Почему я получаю различные формы тензоров? Я хочу получить тензор той же формы, который рассчитывается по pytorch.

Как добиться того же результата, который дает pytorch?

1 Ответ

0 голосов
/ 01 сентября 2018

Если я вас правильно понимаю, то tf.gather_nd - это то, что вы ищете. Если нет, пожалуйста, будьте немного яснее.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...