Для меня работает следующее:
import numpy as np
from numba import jit
@jit(nopython=True)
def rk4(t_prev, x_prev, derivs, dt):
k1 = dt * derivs(t_prev, x_prev)
k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
k4 = dt * derivs(t_prev + dt, x_prev + k3)
x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
return x_next
global k, x_0, v_0, t_0, t_f
k = 1
x_0 = 0
v_0 = np.sqrt(k)
t_0 = 0
t_f = 10
dtList = np.logspace(0, -5, 1000)
@jit(nopython=True)
def derivs(t, X):
deriv = np.zeros(2)
deriv[0] = X[1]
deriv[1] = -k * X[0]
return deriv
@jit(nopython=True)
def err(dt):
tList = np.arange(t_0, t_f + dt, dt)
N = tList.shape[0]
XList = np.zeros((N,2))
XList[0][0], XList[0][1] = x_0, v_0
for i in range(N-1):
XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
error = np.abs(XList[-1][0] - np.sin(10))
return error
print(err(.001))
Обратите внимание, что я сделал только два изменения в вашем коде, чтобы заменить вызовы на np.zeros
, которые передаются в списках, либо на tuple
в2-й случай, или просто голое целое в 1-м случае. См. Следующую проблему для объяснения, почему это так:
https://github.com/numba/numba/issues/3993