Индексирование переменной tf в функции потерь - PullRequest
0 голосов
/ 21 февраля 2020

Я определяю пользовательскую функцию потерь в Tensorflow 1.9.0 (не могу обновить из-за ограничений проекта). У меня есть следующие переменные, полученные после разложения по собственным значениям:

# eigw.shape = (?, x)
# eigv.shape = (?, x, y)

Теперь я хочу вычислить argmax из eigw, так что

amax = tf.argmax(eigw, axis=1, output_type=tf.int32)
# amax.shape = (?,)

Я хочу индекс eigv со значениями, указанными в amax, например,

# result.shape = (?, y)

Как мне этого добиться? Я попытался получить к нему доступ напрямую, но при этом столкнулся с проблемой, что фигуры не имеют одинакового ранга. Кроме того, я попытался использовать tf.while_loop, но я новичок в tf, и, таким образом, мне это не удалось.

Какие еще варианты у меня есть? Как мне легче всего решить эту проблему?

Спасибо

Ответы [ 2 ]

1 голос
/ 21 февраля 2020

В вашем конкретном случае c вы можете использовать любую функцию TensorFlow, которая собирает максимальное значение по оси, а не по индексу.

max_value = tf.math.reduce_max(eigw, axis=1)

Вы можете увидеть любые другие параметры в документации. Поскольку больше нет документации по TF 1.9 на tesnorlfow.org, я мог бы найти r1.15, который все еще использовал графики stati c. https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/reduce_max

0 голосов
/ 02 марта 2020

Решение оказалось проще, чем предполагалось. После более глубокого изучения документации я обнаружил, что собственные векторы уже «упорядочены в неубывающем порядке». Так что мне просто нужно было взять последний собственный вектор. Спасибо за вклад всем.

...