Оптимизация цикла с помощью Numba для устойчивости к ошибкам - PullRequest
0 голосов
/ 22 апреля 2020

У меня есть сомнения при использовании numba для оптимизации. Я кодирую итерацию с фиксированной точкой, чтобы вычислить значение определенного массива с именем gamma, который удовлетворяет уравнению f (gamma) = gamma. Я пытаюсь оптимизировать эту функцию с помощью python пакета Numba. Это выглядит следующим образом.

@jit
def fixed_point(gamma_guess):
    for i in range(17):
        gamma_guess=f(gamma_guess)
    return gamma_guess

Numba способна хорошо оптимизировать эту функцию, потому что она знает, сколько раз она будет выполнять операцию, 17 раз, и работает быстро. Но мне нужно контролировать допуск ошибки моей желаемой гаммы, я имею в виду, что разница гаммы и следующей, полученной с помощью итерации с фиксированной точкой, должна быть меньше некоторого числа epsilon = 0,01, затем я попытался

@jit
def fixed_point(gamma_guess):
    err=1000
    gamma_old=gamma_guess.copy()
    while(error>0.01):
        gamma_guess=f(gamma_guess)
        err=np.max(abs(gamma_guess-gamma_old))
        gamma_old=gamma_guess.copy()
    return gamma_guess

Он также работает и вычисляет желаемый результат, но не так быстро, как в прошлой реализации, он намного медленнее. Я думаю, это потому, что Numba не может хорошо оптимизировать цикл while, поскольку мы не знаем, когда он остановится. Есть ли способ, которым я могу оптимизировать это и работать так же быстро, как в прошлой реализации?

Редактировать:

Вот то, что я использую

from scipy import fftpack as sp
S=0.01
Amu=0.7
@jit 
def f(gammaa,z,zal,kappa):
    ka=sp.diff(kappa)
    gamma0=gammaa
    for i in range(N):
        suma=0
        for j in range(N):
            if (abs(j-i))%2 ==1:
                if((z[i]-z[j])==0):
                    suma+=(gamma0[j]/(z[i]-z[j]))   
        gamma0[i]=2.0*Amu*np.real(-(zal[i]/z[i])+zal[i]*(1.0/(2*np.pi*1j))*suma*2*h)+S*ka[i]
    return  gamma0

Я всегда используйте np.ones(2048)*0.5 в качестве исходного предположения, а другие параметры, которые я передаю своей функции: z=np.cos(alphas)+1j*(np.sin(alphas)+0.1), zal=-np.sin(alphas)+1j*np.cos(alphas), kappa=np.ones(2048) и alphas=np.arange(0,2*np.pi,2*np.pi/2048)

1 Ответ

0 голосов
/ 22 апреля 2020

Я сделал небольшой тестовый скрипт, чтобы посмотреть, смогу ли я воспроизвести вашу ошибку:

import numba as nb

from IPython import get_ipython
ipython = get_ipython()

@nb.jit(nopython=True)
def f(x):
    return (x+1)/x


def fixed_point_for(x):
    for _ in range(17):
        x = f(x)
    return x

@nb.jit(nopython=True)
def fixed_point_for_nb(x):
    for _ in range(17):
        x = f(x)
    return x

def fixed_point_while(x):
    error=1
    x_old = x
    while error>0.01:
        x = f(x)
        error = abs(x_old-x)
        x_old = x
    return x

@nb.jit(nopython=True)
def fixed_point_while_nb(x):
    error=1
    x_old = x
    while error>0.01:
        x = f(x)
        error = abs(x_old-x)
        x_old = x
    return x

print("for loop without numba:")
ipython.magic("%timeit fixed_point_for(10)")

print("for loop with numba:")
ipython.magic("%timeit fixed_point_for_nb(10)")

print("while loop without numba:")
ipython.magic("%timeit fixed_point_while(10)")

print("for loop with numba:")
ipython.magic("%timeit fixed_point_while_nb(10)")

Поскольку я не знаю о вашем f, я просто использовал самую простую стабилизирующую функцию, которую я мог придумать. Затем я запускал тесты с numba и без него, оба раза с циклами for и while. Результаты на моей машине:

for loop without numba:
3.35 µs ± 8.72 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
for loop with numba:
282 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
while loop without numba:
1.86 µs ± 7.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
for loop with numba:
214 ns ± 1.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

Возникают следующие мысли:

  • Не может быть, чтобы ваша функция не оптимизировалась, поскольку ваша for l oop - это быстро (по крайней мере, вы так сказали; тестировали ли вы без numba?).
  • Может случиться так, что вашей функции потребуется гораздо больше циклов, чтобы сходиться, как вы думаете
  • Мы используем разные версии программного обеспечения. Мои версии:
    • numba 0.49.0
    • numpy 1.18.3
    • python 3.8.2
...