NumPy Argmax в массиве с несколькими скобками - PullRequest
0 голосов
/ 20 сентября 2019

У меня есть проблема в применении argmax к массиву, который имеет несколько скобок.В реальной жизни я получаю это в результате тензора pytorch.Здесь я могу привести пример:

a = np.array([[1.0, 1.1],[2.1,2.0]])
np.argmax(a,axis=1)

array([1, 0])

Это правильно.Но:

a = np.array([[[1.0, 1.1]],[[2.1,2.0]]])
np.argmax(a,axis=1)

array([[0, 0],
       [0, 0]])

Это не дает мне того, чего я ожидаю.Учтите, что на самом деле у меня есть уровень внутренних скобок:

a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])

Ответы [ 2 ]

1 голос
/ 20 сентября 2019

Используйте .squeeze() и отрицательный индекс.

a = np.array([[[[1.0, 1.1]]], [[[2.1, 2.0]]]])
np.argmax(a, axis = -1).squeeze()

array([1, 0], dtype=int32)
0 голосов
/ 20 сентября 2019

Возможным решением является увеличение значения оси:

a = np.array([[[[1.0, 1.1]]],[[[2.1,2.0]]]])
np.argmax(a,axis=3)

array([[[1]],
       [[0]]])

Но у меня все еще есть внутренние скобки.

...