Пытаясь ускорить алгоритм DP на python, Нумба казался подходящим кандидатом.
Я делаю вычитание двумерного массива с одномерным массивом, который доставляет трехмерный массив. Затем я использую .argmin()
вдоль 3-го измерения, чтобы получить 2D-массив. Это прекрасно работает с numpy, но не с numba.
Игрушечный код, воспроизводящий проблему:
from numba import jit
import numpy as np
inflow = np.arange(1,0,-0.01) # Dim [T]
actions = np.arange(0,1,0.05) # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49 # Dim [Nx1]
disc_lvl = np.arange(0,1000) # Dim [O]
@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
for i in range(0,100):
# Calculate new level at time i
new_lvl = start_lvl + inflow[i] + actions # Dim [N x M]
# For each new_level element, find closest discretized level
diff = (disc_lvl-new_lvl[:,:,np.newaxis]) # Dim [N x M x O]
idx_lvl = abs(diff).argmin(axis=2) # Dim [N x M]
return True
# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)
Почему не работает приведенный выше код? Это делает при вынимании @jit(nopython=True)
. Есть ли обходной путь, чтобы следующий расчет работал с Numba?
Я пробовал варианты с numpy repeat & expand_dims, а также без явного определения типов ввода функции jit.