Не планировал делать ни одного поста. Я ожидал, что Нумба победит в любых условиях, но этого не должно было быть. Провел несколько тестов на предложенные решения, и результаты оказались несколько интересными, поэтому размещаем здесь. Я собираюсь использовать данные массива, чтобы упростить задачу.
# Proposed solutions
import numpy as np
from numba import njit
# @piRSquared's soln
@njit
def find_first_gt(a, n, value):
while a[n] <= value:
n += 1
return n
# @Ehsan's soln
def numpy_argmax(a, n , value):
return np.argmax(a[n:] > value)
Использование пакета benchit
(несколько инструментов для сравнения, собранные вместе; отказ от ответственности: я его автор) для сравнения предлагаемых решений.
Время и ускорения -
# Benchmark
a = np.arange(1000_000)
n = 0
import benchit
funcs = [find_first_gt, numpy_argmax]
vs = np.linspace(0, len(a)-1, num=20, endpoint=True).astype(int)
inputs = [(a,0,v) for v in vs]
t = benchit.timings(funcs, inputs, multivar=True, input_name='Position of value')
t.plot(logy=False, logx=False, savepath='plot.png')
t.speedups(ref_func_by_index=1).plot('Speedup_with_numba.png')
Если вас интересуют точные цифры ускорения -
In [12]: t.speedups(ref_func_by_index=1)
Out[12]:
Functions find_first_gt Ref:numpy_argmax
Position of value
0 2103.548010 1.0
52631 22.053699 1.0
105263 11.109615 1.0
157894 7.541725 1.0
210526 5.640514 1.0
263157 4.407300 1.0
315789 3.642989 1.0
368420 3.028726 1.0
421052 2.543713 1.0
473683 2.201336 1.0
526315 1.931540 1.0
578946 1.692138 1.0
631578 1.536912 1.0
684209 1.455065 1.0
736841 1.357728 1.0
789472 1.248716 1.0
842104 1.176199 1.0
894735 1.062174 1.0
947367 1.043791 1.0
999999 0.983419 1.0
Вывод: почти во всех условиях numba
делает хорошую работу, если только вы не знаете, что value
находится в самом дальнем конце или в тупике. схемы кеширования не устраивают вас.