Я пытаюсь выучить numba, и поэтому в качестве вступительного упражнения я написал простое решение для орбиты:
import numba as nb
import numpy as np
from timeit import default_timer as timer
spec = [('x0', nb.types.float64),
('y0', nb.types.float64),
('vx0', nb.types.float64),
('vy0', nb.types.float64),
('mass', nb.types.float64),
('ax', nb.types.float64),
('ay', nb.types.float64),
('x', nb.types.float64[:]),
('y', nb.types.float64[:]),
('vx', nb.types.float64[:]),
('vy', nb.types.float64[:])]
@nb.jitclass(spec)
class CelestialBody():
def __init__(self, x0, y0, vx0, vy0, mass):
self.x0 = x0
self.y0 = y0
self.vx0 = vx0
self.vy0 = vy0
self.mass = mass
self.ax = 0.0
self.ay = 0.0
@nb.jit(nopython=True, cache=True)
def orbit(bodies, delta_t, nsteps):
# Set up position arrays
for j in range(len(bodies)):
bodies[j].x = np.zeros(nsteps, dtype=np.float64)
bodies[j].y = np.zeros(nsteps, dtype=np.float64)
bodies[j].vx = np.zeros(nsteps, dtype=np.float64)
bodies[j].vy = np.zeros(nsteps, dtype=np.float64)
bodies[j].x[0] = bodies[j].x0
bodies[j].y[0] = bodies[j].y0
bodies[j].vx[0] = bodies[j].vx0
bodies[j].vy[0] = bodies[j].vy0
# Loop over every time step (skip 0 since we have x0 and y0)
for i in range(0, nsteps-1):
# Get gravitational acceleration for each body at current time
for j in range(len(bodies)):
# Reset accelerations
bodies[j].ax = 0.0
bodies[j].ay = 0.0
for k in range(len(bodies)):
if j != k:
# Get distance between objects
dx = bodies[j].x[i] - bodies[k].x[i]
dy = bodies[j].y[i] - bodies[k].y[i]
d = np.sqrt(dx**2. + dy**2.)
# Get acceleration
a = -bodies[k].mass / d**2.
# Separate into x and y components
theta = np.arctan2(dy, dx)
bodies[j].ax += a * np.cos(theta)
bodies[j].ay += a * np.sin(theta)
# Update positions
for j in range(len(bodies)):
bodies[j].vx[i+1] += bodies[j].vx[i] + bodies[j].ax * delta_t
bodies[j].vy[i+1] += bodies[j].vy[i] + bodies[j].ay * delta_t
bodies[j].x[i+1] += bodies[j].x[i] + bodies[j].vx[i] * delta_t +\
0.5 * bodies[j].ax * delta_t**2.
bodies[j].y[i+1] += bodies[j].y[i] + bodies[j].vy[i] * delta_t + 0.5 *\
bodies[j].ay * delta_t**2
return bodies
for i in range(10):
# Set up celestial bodies
sun = CelestialBody(0., 0., 0., 0., 1.)
earth = CelestialBody(1., 0., 0., 6.33, 3.00e-6)
bodies = [sun, earth]
# Set up time info
tf = 100.
delta_t = tf / 365.
nsteps = int(tf / delta_t)
# Orbit
start = timer()
bodies = orbit(bodies, delta_t, nsteps)
end = timer()
print('Time to run: %f' % (end - start))
Код работает и работает без numba.Когда я добавляю numba, я могу совмещать в себе как класс, так и функцию, и все работает отлично, обеспечивая хорошую скорость.Однако, когда я пытаюсь кэшировать функцию jitt с помощью cache = True, я получаю KeyError:
File "/usr/local/lib/python3.6/dist-packages/numba/caching.py", line 482, in save
data_name = overloads[key]
KeyError: ((reflected list(instance.jitclass.CelestialBody#2cef1b8<x0:float64,
y0:float64,vx0:float64,vy0:float64,mass:float64,ax:float64,ay:float64,
x:array(float64, 1d, A),y:array(float64, 1d, A),vx:array(float64, 1d, A),
vy:array(float64, 1d, A)>), float64, int64), ('x86_64-unknown-linux-gnu',
'skylake', '+adx,+aes,+avx,+avx2,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,
-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,
-avx512vnni,-avx512vpopcntdq,+bmi,+bmi2,-cldemote,+clflushopt,-clwb,-clzero,+cmov,
+cx16,+f16c,+fma,-fma4,+fsgsbase,-gfni,+invpcid,-lwp,+lzcnt,+mmx,+movbe,-movdir64b,
-movdiri,-mwaitx,+pclmul,-pconfig,-pku,+popcnt,-prefetchwt1,+prfchw,-ptwrite,
-rdpid,+rdrnd,+rdseed,-rtm,+sahf,+sgx,-sha,-shstk,+sse,+sse2,+sse3,+sse4.1,
+sse4.2,-sse4a,+ssse3,-tbm,-vaes,-vpclmulqdq,-waitpkg,-wbnoinvd,-xop,+xsave,
+xsavec,+xsaveopt,+xsaves'))
Я понимаю, что большинство из вышеперечисленных - это флаги компилятора и такие, и, вероятно, ненужные, но я не былне уверен, поэтому я решил, что я бы включил его.
Есть также ошибка рассола:
_pickle.PicklingError: Can't pickle <class '__main__.CelestialBody'>: it's not the same object as __main__.CelestialBody
Я пытался посмотреть этот вопрос , но насколькокак я могу сказать, нет ошибки импорта, и я не перепутал ни с одним из импортируемых модулей.Я также не бегу в ноутбуке Jupyter, просто терминал.Я предполагаю, что это как-то связано с классом "signature" до и после его компиляции, и pickle запутывается в изменениях.Я могу заставить работать кэширование, когда класс не используется.
Я использую Python версии 3.6.7, numpy версию 1.15.4 и numba версию 0.42.1
Так, мой вопрос, что является причиной этой ошибки, которая мешает кэшированию?Спасибо!