Сортировка 2D Tensorflow для всех осей - PullRequest
0 голосов
/ 11 мая 2018

Я хочу отсортировать 2D тензор от минимальной до максимальной пары элементов. Я попытался сделать это, следуя этому сообщению . Может сортировать, но не идеальная сортировка.

Это мой полный код:

a = [[0.3,0.32],[0.1,5.2],[1.6,0.4],[0.3,0.1],[5.2,2.6],[0.5,1.15]]
a= tf.convert_to_tensor(a)

length = tf.size(a)/2
point = tf.gather(a, tf.nn.top_k(-a[:,0], k=length).indices)

и это вывод:

[[0.1  5.2 ]
 [0.3  0.32]
 [0.3  0.1 ]
 [0.5  1.15]
 [1.6  0.4 ]
 [5.2  2.6 ]]

Предполагается, что выходные данные будут отсортированы совершенно так (посмотрите на 0.3 тензор):

[[0.1  5.2 ]
[0.3  0.1 ]
[0.3  0.32]
[0.5  1.15]
[1.6  0.4 ]
[5.2  2.6 ]]

Кто-нибудь может помочь? Спасибо.

1 Ответ

0 голосов
/ 11 мая 2018

Сортировка сначала по второму столбцу, затем по первому. Этот код (проверено):

import tensorflow as tf

a = [[0.3,0.32],[0.1,5.2],[1.6,0.4],[0.3,0.1],[5.2,2.6],[0.5,1.15]]
a = tf.convert_to_tensor(a)

def sort( a, col ):
    return tf.gather(a, tf.nn.top_k( -a[ :, col ], k=a.get_shape()[ 0 ].value ).indices )

point = sort( sort( a, 1 ), 0 )

with tf.Session() as sess:
    print( sess.run( point ) )

выведет:

[[0.1 5.2]
[0,3 0,1]
[0,3 0,32]
[0,5 1,15]
[1,6 0,4]
[5.2 2.6]]

по желанию.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...