Я действительно новичок в numba
, поэтому я не очень понимаю, почему это происходит.Я запускаю njit
на функции узкого места моей программы (модель Ising), и она замедляет ее.
Моя функция:
@nb.njit#(nb.types.int64[:](nb.types.int64, nb.types.float64[:], nb.types.int64[:], nb.types.int64[:]))
def heat_bath_mcs(size, Z, s, neighbors):
for step in range(size):
choice = int(size*np.random.random())
energy_variation = (s[neighbors[4*choice]]+s[neighbors[4*choice+1]]
+ s[neighbors[4*choice+2]]+s[neighbors[4*choice+3]])
if np.random.random() < 1.0/(1.0+Z[int(energy_variation*0.5)+2]):
s[choice] = +1
else:
s[choice] = -1
return s
size
- целое число, s
и neighbors
- это два списка Python длиной size
с целочисленными значениями и Z
список длины 4 и значений с плавающей запятой.
Где я пытался вывести типы, но это даетэта ошибка (только с использованием того, что после # в первом коде):
TypeError: Нет соответствующего определения для типа (ов) аргумента int64, отраженный список (float64), отраженный список (int64), отраженный список (int64)
Если я распечатываю типы numba
перед каждым вызовом функции, я получаю:
print(nb.typeof(size),nb.typeof(Z),nb.typeof(s),nb.typeof(neighbors))
Результат:
int64 reflected list(float64) reflected list(int64) reflected list(int64)
Итак, мой вопрос: почему это происходит?Я предполагаю, что я делаю что-то не так, как я могу улучшить свой код, чтобы ускорить его?