Как я могу использовать np argmax в массиве списков? - PullRequest
1 голос
/ 18 апреля 2019

У меня есть данные y_hat, которые выглядят так:

[[0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 ...
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]]

Я хочу получить argmax каждой строки, чтобы у меня был такой вектор:

[[3]
 [8]
 [8]
 ...
 [5]
 [1]
 [7]]

Если я просто наберу np.argmax(y_hat), он возвращает 1.

Ответы [ 2 ]

1 голос
/ 18 апреля 2019

Вот один путь после argmax с numpy трансляцией

a.argmax(axis = 1)[:,None]

Или

a[:,None].argmax(-1)
1 голос
/ 18 апреля 2019

np.argmax принимает аргумент ключевого слова axis .Используйте это.

Это axis=0 для столбцов, axis=1 для строк.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...