индексация трехмерного пакетного тензора в тензорном потоке - PullRequest
0 голосов
/ 02 июня 2018

У меня есть 3D-тензор params с размером (Nx128x3), где N - размер пакета.

У меня также есть индекс тензора индекса = (Nx16), который содержит индикаторы во втором измерении A. Я хочу получить всю строку для данного индекса, чтобы результат был (Nx16x3).

В настоящее время я использую следующий код

gathered = tf.reshape(tf.gather(tf.reshape(params,-1,int(params.shape[2])]),tf.reshape(indices,[-1,])),[-1,int(indices.shape[1]),int(params.shape[2])])

Есть ли способ написать это с помощью collect_nd?

- Мой текущий полный код:

import tensorflow as tf
import numpy as np

N = 10
params = tf.constant(np.random.randn(N, 128, 3), dtype=tf.float32)
indices = tf.constant(np.random.randint(0, 128, [N,16]), dtype=tf.int32)

gathered = tf.reshape(tf.gather(tf.reshape(params,[-1,int(params.shape[2])]),tf.reshape(indices,[-1,])),[-1,int(indices.shape[1]),int(params.shape[2])])

with tf.Session() as sess:
    result = sess.run(gathered)
print('params: ',params.shape)
print('ind: ',indices.shape)
print('result: ',result.shape)
...