Эффективно вернуть индекс первого значения, удовлетворяющего условию в массиве - PullRequest
0 голосов
/ 27 октября 2018

Мне нужно найти индекс первого значения в массиве 1d NumPy или числовой ряд Pandas, удовлетворяющий условию.Массив большой, и индекс может находиться рядом с началом или концом массива, или , условие может не выполняться вообще.Я не могу сказать заранее, что более вероятно.Если условие не выполняется, возвращаемое значение должно быть -1.Я рассмотрел несколько подходов.

Попытка 1

# func(arr) returns a Boolean array
idx = next(iter(np.where(func(arr))[0]), -1)

Но это часто слишком медленно, поскольку func(arr) применяет векторизованную функцию ко всему массиву * чем остановиться, когда условие выполнено.В частности, это дорого, когда условие выполняется около начала массива.

Попытка 2

np.argmax немного быстрее,но не может определить, когда условие никогда не выполнено:

np.random.seed(0)
arr = np.random.rand(10**7)

assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)

%timeit next(iter(np.where(arr > 0.999999)[0]), -1)  # 21.2 ms
%timeit np.argmax(arr > 0.999999)                    # 17.7 ms

np.argmax(arr > 1.0) возвращает 0, т. е. случай, когда условие не выполнено.

Попытка 3

# func(arr) returns a Boolean scalar
idx = next((idx for idx, val in enumerate(arr) if func(arr)), -1)

Но это слишком медленно, когда условие выполняется около конца массива.Предположительно, это потому, что выражение генератора имеет дорогостоящие накладные расходы из-за большого числа __next__ вызовов.

Является ли это всегда компромиссом или есть способ для универсального func,для эффективного извлечения первого индекса?

Сравнительный анализ

Для сравнительного анализа предположим, что func находит индекс, когда значение больше заданной константы:

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

# Start of array benchmark
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

# End of array benchmark
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms

1 Ответ

0 голосов
/ 27 октября 2018

numba

С помощью numba можно оптимизировать оба сценария.Синтаксически, вам нужно только создать функцию с простым циклом for:

from numba import njit

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

idx = get_first_index_nb(A, 0.9)

Numba повышает производительность за счет JIT ("Just In Time") компиляции кода и использования оптимизации на уровне процессора .Обычный цикл for без декоратора @njit обычно будет на медленнее , чем методы, которые вы уже пробовали для случая, когда условие выполняется поздно.

Для числового ряда Pandas df['data'] вы можете просто передать представление NumPy в JIT-скомпилированную функцию:

idx = get_first_index_nb(df['data'].values, 0.9)

Обобщение

Поскольку numba разрешает функциив качестве аргументов и предполагая, что переданная функция также может быть JIT-скомпилирована, вы можете получить метод для вычисления n -го индекса, где выполняется условие для произвольного func.

@njit
def get_nth_index_count(A, func, count):
    c = 0
    for i in range(len(A)):
        if func(A[i]):
            c += 1
            if c == count:
                return i
    return -1

@njit
def func(val):
    return val > 0.9

# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)

Для 3-го последнего значения вы можете задать обратное значение, arr[::-1] и отменить результат с len(arr) - 1, - 1, необходимым для учета 0-индексации.

Сравнительный анализ производительности

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

def get_first_index_np(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

%timeit get_first_index_nb(arr, m)                                 # 375 ns
%timeit get_first_index_np(arr, m)                                 # 2.71 µs
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

%timeit get_first_index_nb(arr, n)                                 # 204 µs
%timeit get_first_index_np(arr, n)                                 # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...