Я нашел другой способ ускорить ваш код еще в 23 раза.С тех пор как я ответил на этот вопрос, я узнал больше о пакете Python под названием numba
.Этот пакет действительно умный, и он по сути переводит ваши функции Python в машинный код и может дать вам огромные ускорения (в зависимости от того, что вы делаете).
Я поиграл с этим пакетом и подумализ вашего вопроса, я хотел посмотреть, насколько быстро Numba может сделать это по сравнению с использованием массивов NumPy.Проводя сравнение между вашим кодом, версией, которую я вставил в ответе выше, и оптимизированной по numba версии вашего кода.Я обнаружил, что версия numpy (без цикла) была в 23 раза быстрее, чем ваша версия цикла, но версия numba была в 455 раз быстрее!
В любом случае, я подумал, что вам может понравиться это увидеть.Версия numba примерно в 23 раза быстрее, чем мой ответ выше.Ниже приведены 3 версии кода, который я использовал.Я превратил ваш код в функцию под названием original_loop_function
, версию кода, которую я разместил выше, в функцию numpy_function
, а более быстрой версией кода на языке numba является функция numba_loop_function
.Надеюсь, это все еще полезно для вас
import numpy as np
from numba import jit
cpx_var = np.linspace(0.20,0.80,301)
horn_var = np.linspace(0.20,0.80,301)
plag_var = np.linspace(0.05,0.1,26)
mag_var = np.linspace(0.02,0.06,21)
ap_var = np.linspace(0.002,0.006,3)
def original_loop_function(cpx_var, horn_var, plag_var, mag_var, ap_var):
'''Your original version of this code'''
poss_comb = []
count=0
for i in cpx_var:
for j in horn_var:
for k in plag_var:
for l in mag_var:
for m in ap_var:
count = count+1
if abs((i+j+k+l+m)-1.0)<0.002:
poss_comb.append([i,j,k,l,m])
return poss_comb
def numpy_function(cpx_var, horn_var, plag_var, mag_var, ap_var):
'''Numpy version of your code with no loops'''
cpx_vars, horn_vars, plag_vars, mag_vars, ap_vars = np.meshgrid(cpx_var, horn_var, plag_var, mag_var, ap_var)
selection = abs((cpx_vars + horn_vars + plag_vars + mag_vars + ap_vars)-1.0)<0.002
poss_comb = [list(comb) for comb in zip(cpx_vars[selection], horn_vars[selection], plag_vars[selection], mag_vars[selection], ap_vars[selection])]
return poss_comb
@jit(nopython=True)
def numba_loop_function(cpx_var, horn_var, plag_var, mag_var, ap_var):
'''Numba optimised version of your code'''
poss_comb = []
count=0
for i in cpx_var:
for j in horn_var:
for k in plag_var:
for l in mag_var:
for m in ap_var:
count = count+1
if abs((i+j+k+l+m)-1.0)<0.002:
poss_comb.append([i,j,k,l,m])
return poss_comb