NumPy: самый быстрый способ получить argmax сумм строк трехмерного массива - PullRequest
1 голос
/ 26 июня 2019

Предположим, у меня есть 3-мерный массив с размерами 10x10000x5. Интерпретируя этот массив как 10 «подмассивов», каждый из которых содержит 10000 строк и 5 столбцов, я хочу сделать для каждой строки:

(1) Вычислить сумму строки в каждом из 10 подмассивов.

(2) Определите, какой подмассив дает наибольшую сумму.

Пример показан ниже. Я делаю выше, но только для первых двух строк, где firstrow - сумма первой строки каждого подмассива, а secondrow - сумма второй строки каждого подмассива. Затем я использую np.argmax (), чтобы найти подмассив, который дает наибольшую сумму. Но я хочу сделать это для всех 10000 строк, а не только для первых двух.

import numpy as np
np.random.seed(777)
A = np.random.randn(10,10000,5)

first = [None]*10
second = [None]*10
for i in range(10):
    firstrow[i] = A[i].sum(axis=1)[0]
    secondrow[i] = A[i].sum(axis=1)[1]

np.argmax(np.array(firstrow)) # Sub-array 9 yields the highest sum
np.argmax(np.array(secondrow)) # Sub-array 8 yields the highest sum
#...

Какой самый быстрый способ сделать это для всех 10000 строк?

1 Ответ

1 голос
/ 26 июня 2019

Вы можете сделать это так:

result = A.sum(2).argmax(0)

Протестировано в вашем примере:

import numpy as np

np.random.seed(777)
A = np.random.randn(10, 10000, 5)

result = A.sum(2).argmax(0)

# Check against loop
first = [None] * 10
second = [None] * 10
for i in range(10):
    first[i] = A[i].sum(axis=1)[0]
    second[i] = A[i].sum(axis=1)[1]

print(result[0], np.argmax(np.array(first)))
# 9 9
print(result[1], np.argmax(np.array(second)))
# 8 8
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...