Как извлечь строки и столбцы из трехмерного массива в Tensorflow - PullRequest
0 голосов
/ 18 октября 2018

Я хотел сделать следующую операцию индексации для тензора TensorFlow.Какими должны быть эквивалентные операции в TensorFlow для получения b и c в качестве вывода?Хотя tf.gather_nd документация имеет несколько примеров, но я не смог сгенерировать эквивалентный indices тензор для получения этих результатов.

import tensorflow as tf
import numpy as np

a=np.arange(18).reshape((2,3,3))

idx=[2,0,1] #it can be any validing re-ordering index list

#These are the two numpy operations that I want to do in Tensorflow
b=a[:,idx,:]
c=a[:,:,idx] 

# TensorFlow operations

aT=tf.constant(a)
idxT=tf.constant(idx)

# what should be these two indices  
idx1T=tf.reshape(idxT, (3,1)) 
idx2T=tf.reshape(idxT, (1,1,3))

bT=tf.gather_nd(aT, idx1T ) #does not work
cT=tf.gather_nd(aT, idx2T)  #does not work

with tf.Session() as sess:
    b1,c1=sess.run([bT,cT])

print(np.allclose(b,b1))
print(np.allclose(c,c1))

Я не ограничен tf.gather_nd Любые другие предложения для выполнения тех же операций на GPUбудет полезно.

Редактировать: Я обновил вопрос для опечатки:

старое утверждение: c=a[:,idx],

Новое утверждение: c=a[:,:,idx] То, что я хотелдостижения были также переупорядочены столбцы.

1 Ответ

0 голосов
/ 18 октября 2018

Это можно сделать с помощью tf.gather, используя параметр axis:

import tensorflow as tf
import numpy as np

a = np.arange(18).reshape((2,3,3))
idx = [2,0,1]
b = a[:, idx, :]
c = a[:, :, idx]

aT = tf.constant(a)
idxT = tf.constant(idx)
bT = tf.gather(aT, idxT, axis=1)
cT = tf.gather(aT, idxT, axis=2)

with tf.Session() as sess:
    b1, c1=sess.run([bT, cT])

print(np.allclose(b, b1))
print(np.allclose(c, c1))

Вывод:

True
True
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...