( EDITED : исправить проблему измерения времени, поднятую @ max9111 в комментариях, и включить модифицированное numexpr
решение).
Узкое место в конечном итоге может оказаться в пределах np.all()
вызов. Это может быть ускорено с помощью Numba следующим образом:
import numpy as np
import numba as nb
@nb.jit(nopython=True)
def contains_nb(arr, a_arr, b_arr):
m = a_arr.size
arr = arr.reshape(-1, m)
n = arr.shape[0]
result = np.ones(n, dtype=np.bool8)
for i in range(n):
for j in range(m):
if not a_arr[j] < arr[i, j] < b_arr[j]:
result[i] = False
break
return result
Это сравнивается с решением NumPy:
import numpy as np
def contains_np(arr, a_arr, b_arr):
m = a_arr.size
arr = arr.reshape(-1, m)
return np.all((arr >= a_arr) & (arr <= b_arr), axis=1)
, которое я немного упростил по сравнению с вашим подходом (у меня есть пропущенные параметры dim
и strict
, поскольку dim
является избыточным, так как он может быть выведен из размеров a_arr
или b_arr
, в то время как параметр strict
не добавляет много к решению, но может быть легко вводится). Я также предполагаю, что входные данные уже всегда являются массивом NumPy.
Кроме того, решение NumPy можно изменить, чтобы использовать numexpr
, что приводит к третьему подходу. Это вызовет некоторые накладные расходы, но может ускорить вычисления, например:
import numpy as np
import numexpr as ne
def contains_ne(arr, a_arr, b_arr):
m = a_arr.size
arr = arr.reshape(-1, m)
result = ne.evaluate('(arr >= a_arr) & (arr <= b_arr)')
return np.all(result, axis=1)
Можно получить следующие тесты:
Это показывает, что решение Numba неизменно является самым быстрым. Напротив, использование numexpr
представляется бесполезным для диапазона исследуемых параметров.
(доступно полное тестирование здесь )