Как я могу использовать JIT NUMBA в моей программе, которая содержит только массивы Numpy? - PullRequest
2 голосов
/ 07 октября 2019

Моя программа оценивает ошибку при решении линейного дифференциального уравнения. Используются только массивы NumPy. Когда я пытаюсь использовать jit decorator для функций, которые я определяю, я просто получаю ошибки. Не могли бы вы помочь мне правильно его использовать?

Мой код:

import numpy as np
from numba import jit

def rk4(t_prev, x_prev, derivs, dt):
    k1 = dt * derivs(t_prev, x_prev)
    k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
    k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
    k4 = dt * derivs(t_prev + dt, x_prev + k3)
    x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
    return x_next

global k, x_0, v_0, t_0, t_f

k = 1

x_0 = 0
v_0 = np.sqrt(k)

t_0 = 0
t_f = 10

dtList = np.logspace(0, -5, 1000)


def derivs(t, X):
    deriv = np.zeros([2])
    deriv[0] = X[1]
    deriv[1] = -k * X[0]
    return deriv


def err(dt):
    tList = np.arange(t_0, t_f + dt, dt)
    N = tList.shape[0]
    XList = np.zeros([N,2])
    XList[0][0], XList[0][1] = x_0, v_0
    for i in range(N-1):
        XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
    error = np.abs(XList[-1][0] - np.sin(10))
    return error

print(err(.001))

1 Ответ

1 голос
/ 07 октября 2019

Для меня работает следующее:

import numpy as np
from numba import jit

@jit(nopython=True)
def rk4(t_prev, x_prev, derivs, dt):
    k1 = dt * derivs(t_prev, x_prev)
    k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
    k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
    k4 = dt * derivs(t_prev + dt, x_prev + k3)
    x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
    return x_next

global k, x_0, v_0, t_0, t_f

k = 1

x_0 = 0
v_0 = np.sqrt(k)

t_0 = 0
t_f = 10

dtList = np.logspace(0, -5, 1000)

@jit(nopython=True)
def derivs(t, X):
    deriv = np.zeros(2)
    deriv[0] = X[1]
    deriv[1] = -k * X[0]
    return deriv


@jit(nopython=True)
def err(dt):
    tList = np.arange(t_0, t_f + dt, dt)
    N = tList.shape[0]
    XList = np.zeros((N,2))
    XList[0][0], XList[0][1] = x_0, v_0
    for i in range(N-1):
        XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
    error = np.abs(XList[-1][0] - np.sin(10))
    return error

print(err(.001))

Обратите внимание, что я сделал только два изменения в вашем коде, чтобы заменить вызовы на np.zeros, которые передаются в списках, либо на tuple в2-й случай, или просто голое целое в 1-м случае. См. Следующую проблему для объяснения, почему это так:

https://github.com/numba/numba/issues/3993

...