tf.argmax () для более чем одного индекса Tensorflow - PullRequest
0 голосов
/ 19 мая 2018

В Tensorflow функция tf.argmax () возвращает индекс наибольшего элемента в массиве.

Однако для задач классификации с несколькими метками функция, которая возвращает N наибольших элементов в массиве, будет иметь видочень кстати.

predicted_array: [0.4, 0.6, 0.7, 0.2, 0.9]
tf.something(predicted_array, N = 2): [2,4]

Чтобы затем сравнить его с истинным основанием одного горячо закодированного массива

one_hot_array: [0, 0, 1, 0, 1]
tf.something(one_hot_array, N = 2): [2,4]

Есть ли какая-либо функция, подобная этой?Или что-то похожее на это?

Спасибо за любую помощь

1 Ответ

0 голосов
/ 19 мая 2018

Да, есть.Это tf.nn.top_k (из здесь ).

Вы можете использовать его как tf.nn.top_k(predicted_array, k=2)

...