Как получить вектор.из двумерного массива с использованием argmax? - PullRequest
0 голосов
/ 07 февраля 2019

У меня есть следующее numpy.ndarray:

testArr:

array([[  2.55053788e-01,   6.25406146e-01,   1.19271643e-01,
          2.68359261e-04],
       [  2.59611636e-01,   0.19562805e-01,   1.20518960e-01,
          3.06535745e-01],
       [  8.52524495e-01,   5.24317825e-01,   1.22851081e-01,
          3.06610862e-04],
       [  2.55068243e-01,   6.24345124e-01,   1.20263465e-01,
          3.23178538e-04],
       [  2.46678621e-01,   6.29301071e-01,   1.23693809e-01,
          3.26490292e-04]], dtype=float32)

Если я наберу testVec = np.argmax(testArr), я получу одно число.Как получить вектор 0, 1 или 2, в зависимости от максимального значения в каждой строке двумерного массива testArr?

Ожидаемый результат:

[1, 3, 0, 1, 1]

Ответы [ 2 ]

0 голосов
/ 07 февраля 2019

По умолчанию np.argmax дает индекс максимального значения в сглаженном массиве.Чтобы получить максимум по одному измерению (например, максимальное значение в каждой строке), необходимо указать ключевое слово аргумент axis.Это должно быть целое число: 0 для столбцов, 1 для строк.(Или любое целое число до n-1, если ваш массив имеет n измерений.)

import numpy as np
testArr = np.array([[  2.55053788e-01,   6.25406146e-01,   1.19271643e-01,
          2.68359261e-04],
       [  2.59611636e-01,   0.19562805e-01,   1.20518960e-01,
          3.06535745e-01],
       [  8.52524495e-01,   5.24317825e-01,   1.22851081e-01,
          3.06610862e-04],
       [  2.55068243e-01,   6.24345124e-01,   1.20263465e-01,
          3.23178538e-04],
       [  2.46678621e-01,   6.29301071e-01,   1.23693809e-01,
          3.26490292e-04]], dtype=np.float32)
np.argmax(testArr, axis=1)
>>> array([1, 3, 0, 1, 1])
0 голосов
/ 07 февраля 2019

Если вы посмотрите на documentation, вы увидите, что есть параметр axis, который позволяет вам выбирать, по какой оси вы хотите выполнить операцию.Из документов:

Возвращает индексы максимальных значений вдоль оси.

В этом случае вы хотите:

np.argmax(a, axis=1)
# array([1, 3, 0, 1, 1], dtype=int64)
...