Мне нужно найти индекс первого значения в массиве 1d NumPy или числовой ряд Pandas, удовлетворяющий условию.Массив большой, и индекс может находиться рядом с началом или концом массива, или , условие может не выполняться вообще.Я не могу сказать заранее, что более вероятно.Если условие не выполняется, возвращаемое значение должно быть -1
.Я рассмотрел несколько подходов.
Попытка 1
# func(arr) returns a Boolean array
idx = next(iter(np.where(func(arr))[0]), -1)
Но это часто слишком медленно, поскольку func(arr)
применяет векторизованную функцию ко всему массиву * чем остановиться, когда условие выполнено.В частности, это дорого, когда условие выполняется около начала массива.
Попытка 2
np.argmax
немного быстрее,но не может определить, когда условие никогда не выполнено:
np.random.seed(0)
arr = np.random.rand(10**7)
assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)
%timeit next(iter(np.where(arr > 0.999999)[0]), -1) # 21.2 ms
%timeit np.argmax(arr > 0.999999) # 17.7 ms
np.argmax(arr > 1.0)
возвращает 0
, т. е. случай, когда условие не выполнено.
Попытка 3
# func(arr) returns a Boolean scalar
idx = next((idx for idx, val in enumerate(arr) if func(arr)), -1)
Но это слишком медленно, когда условие выполняется около конца массива.Предположительно, это потому, что выражение генератора имеет дорогостоящие накладные расходы из-за большого числа __next__
вызовов.
Является ли это всегда компромиссом или есть способ для универсального func
,для эффективного извлечения первого индекса?
Сравнительный анализ
Для сравнительного анализа предположим, что func
находит индекс, когда значение больше заданной константы:
# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np
np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999
# Start of array benchmark
%timeit next(iter(np.where(arr > m)[0]), -1) # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1) # 2.5 µs
# End of array benchmark
%timeit next(iter(np.where(arr > n)[0]), -1) # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1) # 39.2 ms