Более быстрая альтернатива np.where для отсортированного массива - PullRequest
4 голосов
/ 10 января 2020

Учитывая большой массив a, который сортируется по каждой строке, есть ли более быстрая альтернатива numpy np.where, чтобы найти индексы, где min_v <= a <= max_v? Я бы предположил, что использование отсортированной природы массива должно ускорить процесс.

Вот пример установки, использующей np.where для поиска заданных индексов в большом массиве.

import numpy as np

# Initialise an example of an array in which to search
r, c = int(1e2), int(1e6)
a = np.arange(r*c).reshape(r, c)

# Set up search limits
min_v = (r*c/2)-10
max_v = (r*c/2)+10

# Find indices of occurrences
idx = np.where(((a >= min_v) & (a <= max_v)))

Ответы [ 2 ]

2 голосов
/ 10 января 2020

Вы можете использовать np.searchsorted:

import numpy as np

r, c = 10, 100
a = np.arange(r*c).reshape(r, c)

min_v = ((r * c) // 2) - 10
max_v = ((r * c) // 2) + 10

# Old method
idx = np.where(((a >= min_v) & (a <= max_v)))

# With searchsorted
i1 = np.searchsorted(a.ravel(), min_v, 'left')
i2 = np.searchsorted(a.ravel(), max_v, 'right')
idx2 = np.unravel_index(np.arange(i1, i2), a.shape)
print((idx[0] == idx2[0]).all() and (idx[1] == idx2[1]).all())
# True
1 голос
/ 10 января 2020

Когда я использую np.searchsorted со 100 миллионами чисел в исходном примере с устаревшей NumPy версией 1.12.1 (не могу сказать о более новых версиях), это не намного быстрее, чем np.where:

>>> import timeit
>>> timeit.timeit('np.where(((a >= min_v) & (a <= max_v)))', number=10, globals=globals())
6.685825735330582
>>> timeit.timeit('np.searchsorted(a.ravel(), [min_v, max_v])', number=10, globals=globals())
5.304438766092062

Но, несмотря на NumPy документы для searchsorted скажем Эта функция использует тот же алгоритм, что и встроенные функции python bisect.bisect_left и bisect.bisect_right, последние намного быстрее:

>>> import bisect
>>> timeit.timeit('bisect.bisect_left(a.base, min_v), bisect.bisect_right(a.base, max_v)', number=10, globals=globals())
0.002058468759059906

Поэтому я бы использовал это:

idx = np.unravel_index(range(bisect.bisect_left(a.base, min_v),
                             bisect.bisect_right(a.base, max_v)), a.shape)
...