NumPy - получить индекс строки с суммой строк больше 10 - PullRequest
2 голосов
/ 27 сентября 2019

У меня есть следующий массив:

a = np.array([1,2,9], [5,2,4], [1,2,3])

Задача - найти индексы для всех строк с суммой строк, превышающей 10 , в моем примере результатдолжен выглядеть как [0, 1]

Мне нужен фильтр, подобный фильтру, рекомендованному в этом посте: Фильтровать строки массива numpy?

Однако я тольконужны индексы, а не фактические значения или их собственный массив.

Мой текущий код выглядит следующим образом:

temp = a[np.sum(a, axis=1) > 5]

Как получить начальные индексы отфильтрованных строк?

Ответы [ 4 ]

4 голосов
/ 27 сентября 2019

Вы можете использовать np.argwhere() примерно так:

>>> import numpy as np

>>> a = np.array([[1,2,9], [5,2,4], [1,2,3]])
>>> np.argwhere(np.sum(a, axis=1) > 10)
[[0]
 [1]]
1 голос
/ 27 сентября 2019

Вы можете проверить, где сумма больше 10 и получить индексы с помощью np.flatnonzero:

a = np.array([[1,2,9], [5,2,4], [1,2,3]])

np.flatnonzero(a.sum(1) > 10)
# array([0, 1], dtype=int64)
0 голосов
/ 27 сентября 2019

Вы можете просто использовать:

temp = np.sum(a, axis=1) > 10
np.arange(len(a))[temp]
0 голосов
/ 27 сентября 2019

я пробую более одного кода.Лучше всего выглядит вторая версия:

import numpy as np
a = np.array([[1,2,9], [5,2,4], [1,2,1]])
print(a)

%timeit temp = a[np.sum(a, axis=1) > 5]
temp = a[np.sum(a, axis=1) > 5]
print(temp)

%timeit temp = [n for n, curr in enumerate(a) if sum(curr) > 5 ]
temp = [n for n, curr in enumerate(a) if sum(curr) > 5 ]
print(temp)

%timeit temp = np.argwhere(np.sum(a, axis=1) > 5)
temp = np.argwhere(np.sum(a, axis=1) > 5)
print(temp)

%timeit temp = np.flatnonzero(a.sum(1) > 10)
temp = np.flatnonzero(a.sum(1) > 10)
print(temp)

Результаты:

[[1 2 9]
 [5 2 4]
 [1 2 1]]
The slowest run took 12.37 times longer than the fastest. This could mean that     an intermediate result is being cached.
100000 loops, best of 3: 7.47 µs per loop
[[1 2 9]
 [5 2 4]]
100000 loops, best of 3: 5.09 µs per loop
[0, 1]
The slowest run took 9.83 times longer than the fastest. This could mean that     an intermediate result is being cached.
100000 loops, best of 3: 13.3 µs per loop
[[0]
 [1]]
The slowest run took 6.78 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.8 µs per loop
[0 1]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...