Время компиляции функции njit становится слишком длинным, когда один параметр представляет собой большой список jitclass - PullRequest
0 голосов
/ 17 января 2020

У меня есть Jitclass, который представляет модель ODE FHNfunc. Мне нужно создать большое количество этого класса и решить функцию для каждого из них. Для этого я хотел бы использовать функцию njit. Моя проблема заключается в том, что время компиляции сходит с ума, когда число используемых моделей jitclass становится большим

здесь таблица для представления моей точки:

number of model / Time to compute 
1 0.48170995712280273
10 1.875981092453003
20 5.989977598190308
30 12.749890327453613
40 22.722237586975098
50 35.16194796562195

Как мне понадобятся тысячи экземпляров jitclass, Я не могу тратить часы на компиляцию функции njit Model_compute.

Нормально ли это поведение или я неправильно написал функцию?

вот полный MWE

import numpy as np
from numba import jitclass, njit
from numba import int32, float64
import time

spec = [('V_init' ,float64),
        ('a' ,float64),
        ('b' ,float64),
        ('g',float64),
        ('dt' ,float64),
        ('NbODEs',int32),
        ('dydx' ,float64[:]),
        ('time' ,float64[:]),
        ('V' ,float64[:]),
        ('W' ,float64[:]),
        ('y'    ,float64[:]) ]

@jitclass(spec, )
class FHNfunc:
    def __init__(self,):
        self.V_init = .04
        self.a= 0.25
        self.b=0.001
        self.g = 0.003
        self.dt = .01
        self.dydx    =np.zeros(2)
        self.y    =np.zeros(2)

    def Eul(self,):
        self.deriv()
        self.y += (self.dydx * self.dt)

    def deriv(self,):
        self.dydx[0]= self.V_init - self.y[0] *(self.a-(self.y[0]))*(1-(self.y[0]))-self.y[1]
        self.dydx[1]= self.b * self.y[0] - self.g * self.y[1]
        return

@njit
def Model_compute(tp, FH_list):
    for i in range(len(FH_list)):
        res = np.zeros(len(tp))
        for tt, t in enumerate(tp):
            for model in range(len(FH_list)):
                FH_list[model].Eul()
            res[tt] = FH_list[0].y[0]
    return res

class severalInstances():
    def __init__(self):
        self.FH_list = []
        self.Nbmodel = 50
        for i in range(self.Nbmodel):
            self.FH_list.append(FHNfunc())

        self.dt = .01
        self.tp = np.linspace(0, 1000, num=int((1000) / self.dt))

    def Simule(self):
        t0 = time.time()
        res = Model_compute(self.tp, self.FH_list)
        print(self.Nbmodel,time.time()-t0, res)

if __name__ == "__main__":
    instance = severalInstances()
    instance.Simule()
...