( EDITED )
( EDIT2 : добавлена более специализированная версия JIT для решения проблем при использовании np.sort()
с numba
.)
( EDIT3 : включено время для рекурсивного подхода с медианным поворотом от @ hilberts_drinking_problem's answer )
Я не 100%, что вы после, потому что первые две строки вашего кода, кажется, ничего не делают, но после @hilberts_drinking_problem я отредактировал свой ответ, я предполагаю, что у вас есть опечатка и:
sum_ = np.sum(arr[:i])
должно быть:
sum_ = np.sum(asorted[:i])
Тогда ваше решение можно записать в виде функции, такой как:
import numpy as np
def min_sum_threshold_orig(arr, threshold=0.5):
idx = np.argsort(arr)
arr_sorted = arr[idx][::-1]
sum_ = 0
i = 0
while sum_ < threshold:
i += 1
sum_ = np.sum(arr_sorted[:i])
return i
Однако:
- Вместо
np.argsort()
и индексирование, которое вы можете использовать np.sort()
напрямую - нет необходимости вычислять всю сумму на каждой итерации, но вместо этого вы можете использовать сумму из предыдущей итерации
- Использование while l oop рискованно, потому что если
threshold
достаточно высоко (> 1.0
с вашим предположением), тогда l oop ever end
Обращаясь к этим точкам, можно получить:
def min_sum_threshold(arr, threshold=0.5):
arr = np.sort(arr)[::-1]
sum_ = 0
for i in range(arr.size):
sum_ += arr[i]
if sum_ >= threshold:
break
return i + 1
В приведенном выше описании явное зацикливание становится узким местом. Хороший способ решения этой проблемы - использовать numba
:
import numba as nb
min_sum_threshold_nbn = nb.jit(min_sum_threshold)
min_sum_threshold_nbn.__name__ = 'min_sum_threshold_nbn'
Но это может быть не самый эффективный подход, поскольку numba
является относительно медленным при создании новых массивов. Возможно, более быстрый подход заключается в использовании arr.sort()
вместо np.sort()
, потому что это на месте, что позволяет избежать создания нового массива:
@nb.jit
def min_sum_thres_nb_inplace(arr, threshold=0.5):
arr.sort()
sum_ = 0
for i in range(arr.size - 1, -1, -1):
sum_ += arr[i]
if sum_ >= threshold:
break
return arr.size - i
В качестве альтернативы можно выполнить JIT только часть кода после сортировки:
@nb.jit
def _min_sum_thres_nb(arr, threshold=0.5):
sum_ = 0.0
for i in range(arr.size):
sum_ += arr[i]
if sum_ >= threshold:
break
return i + 1
def min_sum_thres_nb(arr, threshold=0.5):
return _min_sum_thres_nb(np.sort(arr)[::-1], threshold)
Разница между ними будет минимальной для больших входов. Для меньшего из них min_sum_thres_nb()
будет зависеть от сравнительно медленного вызова дополнительной функции. Из-за ошибок в функциях бенчмаркинга, которые изменяют их входные данные, min_sum_thres_nb_inplace()
опускается в бенчмарках с пониманием того, что для очень маленьких входов он такой же быстрый, как min_sum_thres_nbn()
, а для более крупных он имеет практически те же характеристики, что и min_sum_thres_nb()
.
В качестве альтернативы можно использовать векторизованные подходы, как в @ yatu's answer :
def min_sum_threshold_np_sum(arr, threshold=0.5):
return np.sum(np.cumsum(np.sort(arr)[::-1]) < threshold) + 1
или, лучше, использовать np.searchsorted()
, что позволяет избежать создания ненужных временный массив со сравнением:
def min_sum_threshold_np_ss(arr, threshold=0.5):
return np.searchsorted(np.cumsum(np.sort(arr)[::-1]), threshold) + 1
или, если предположить, что сортировка всего массива излишне дорогая:
def min_sum_threshold_np_part(arr, threshold=0.5):
n = arr.size
m = np.int(size * threshold) + 1
part_arr = np.partition(arr, n - m)[n - m:]
return np.searchsorted(np.cumsum(np.sort(arr)[::-1]), threshold) + 1
Еще более сложный подход с использованием рекурсии и медианного поворота:
def min_sum_thres_rec(arr, threshold=0.5, cutoff=64):
n = arr.size
if n <= cutoff:
return np.searchsorted(np.cumsum(np.sort(arr)[::-1]), threshold) + 1
else:
m = n // 2
partitioned = np.partition(arr, m)
low = partitioned[:m]
high = partitioned[m:]
sum_high = np.sum(high)
if sum_high >= threshold:
return min_sum_thres_rec(high, threshold)
else:
return min_sum_thres_rec(low, threshold - sum_high) + high.size
(последние три адаптированы из ответа @ hilberts_drinking_problem )
Сравнительный анализ с входными данными, сгенерированными из этого:
def gen_input(n, a=0, b=10000):
arr = np.random.randint(a, b, n)
arr = arr / np.sum(arr)
return arr
дает следующее:
Они указывают на то, что для достаточно малых входов, утверждение numba
Aч - самый быстрый, но как только вход превышает ~ 600 элементов для наивного подхода или ~ 900 для оптимизированного , подход NumPy, который использует np.partition()
, в то же время менее эффективно использует память, быстрее.
В конечном итоге, после ~ 4000 элементов, min_sum_thres_rec()
становится быстрее, чем все другие предложенные методы. Может быть возможно написать более быструю реализацию этого метода на основе чисел.
Обратите внимание, что оптимизированный numba
подход строго быстрее, чем наивные NumPy протестированные подходы.