Tensorflow - сортировка тензора с помощью специального компаратора - PullRequest
0 голосов
/ 03 мая 2018

Как я могу отсортировать тензор Tensorflow целых чисел с формой [n, 2] в соответствии с пользовательской функцией сравнения, используя только операции Tensorflow?

Допустим, в моем тензоре есть две записи [x1, y1] и [x2, y2]. Я хочу отсортировать тензор так, чтобы записи были упорядочены по условию x1 * y2> x2 * y1.

Ответы [ 2 ]

0 голосов
/ 02 апреля 2019

В качестве альтернативы top_k ответу Питера Шолдана , после 1.13 есть tf.argsort.

Для <1.13 используйте </p>

tf.nn.top_k( s, tf.shape(x)[0] )

если не удается получить форму в статическом графике.

0 голосов
/ 03 мая 2018

Предположим, что вы можете создать метрику (если нет, см. Общий случай ниже) для ваших элементов (здесь, с перестановкой неравенства в x1 / y1> x2 / y2 , поэтому метрика будет x / y и полагаться на TensorFlow для получения inf (как в бесконечности) для деления на ноль), используйте tf.nn.top_k() как этот код (проверено):

import tensorflow as tf

x = tf.constant( [ [1,2], [3,4], [1,3], [2,5] ] ) # some example numbers

s = tf.truediv( x[ ..., 0 ], x[ ..., 1 ] ) # your sort condition
val, idx = tf.nn.top_k( s, x.get_shape()[ 0 ].value )
x_sorted = tf.gather( x, idx )

with tf.Session() as sess:
    print( sess.run( x_sorted ) )

Выходы:

[[3 4]
[1 2]
[2 5]
[1 3]]


Если вы не можете или не можете легко создать метрику, тогда все еще существует предположение, что отношение дает вам правильного порядка . (В противном случае результаты не определены.) В этом случае вы строите матрицу сравнения для всего набора и упорядочиваете элементы по сумме строк (т. Е. Сколько других элементов больше); это, конечно, квадратичное число элементов для сортировки. Этот код (проверено):

import tensorflow as tf

x = tf.constant( [ [1,2], [3,4], [1,3], [2,5] ] ) # some example numbers

x1, y1 = x[ ..., 0 ][ None, ... ], x[ ..., 1 ][ None, ... ] # expanding dims into cols
x2, y2 = x[ ..., 0, None ],        x[ ..., 1, None ] # expanding into rows
r = tf.cast( tf.less( x1 * y2, x2 * y1 ), tf.int32 ) # your sort condition, with implicit broadcasting
s = tf.reduce_sum( r, axis = 1 ) # how many other elements are greater

val, idx = tf.nn.top_k( s, s.get_shape()[ 0 ].value )
x_sorted = tf.gather( x, idx )

with tf.Session() as sess:
    print( sess.run( x_sorted ) )

Выходы:

[[3 4]
[1 2]
[2 5]
[1 3]]

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...