Я определяю пользовательскую функцию потерь в Tensorflow 1.9.0 (не могу обновить из-за ограничений проекта). У меня есть следующие переменные, полученные после разложения по собственным значениям:
# eigw.shape = (?, x)
# eigv.shape = (?, x, y)
Теперь я хочу вычислить argmax
из eigw
, так что
amax = tf.argmax(eigw, axis=1, output_type=tf.int32)
# amax.shape = (?,)
Я хочу индекс eigv
со значениями, указанными в amax
, например,
# result.shape = (?, y)
Как мне этого добиться? Я попытался получить к нему доступ напрямую, но при этом столкнулся с проблемой, что фигуры не имеют одинакового ранга. Кроме того, я попытался использовать tf.while_loop
, но я новичок в tf, и, таким образом, мне это не удалось.
Какие еще варианты у меня есть? Как мне легче всего решить эту проблему?
Спасибо