Получение наибольшего индекса от тензора - PullRequest
0 голосов
/ 29 декабря 2018

Мой CNN дал следующее (из model.predict()):

Tensor("input_1:0", shape=(?, 2, 26, 1), dtype=float32)
[9.9952221e-01 2.3613637e-04 1.9953270e-06 1.6922619e-05 2.2012556e-04
 2.4441533e-07 3.5276526e-07 7.4913805e-07 4.0657511e-07 8.7760031e-07]

Я хотел бы получить индекс наибольшего значения из этого массива.Прямо сейчас я попытался сделать это (x - массив выше):

result = x.index(max(x))

Вместо этого возникает ошибка о том, что этот тип данных не поддерживает .index?

1 Ответ

0 голосов
/ 29 декабря 2018

Вы можете просто использовать функцию np.argmax:

import numpy as np

preds = model.predict(test_data)
pred_class = np.argmax(preds, axis=-1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...