Numba и многомерные дополнения - не работают с numpy .newaxis? - PullRequest
1 голос
/ 22 апреля 2020

Пытаясь ускорить алгоритм DP на python, Нумба казался подходящим кандидатом.

Я делаю вычитание двумерного массива с одномерным массивом, который доставляет трехмерный массив. Затем я использую .argmin() вдоль 3-го измерения, чтобы получить 2D-массив. Это прекрасно работает с numpy, но не с numba.

Игрушечный код, воспроизводящий проблему:

from numba import jit
import numpy as np

inflow      = np.arange(1,0,-0.01)                  # Dim [T]
actions     = np.arange(0,1,0.05)                   # Dim [M]
start_lvl   = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]
disc_lvl    = np.arange(0,1000)                     # Dim [O]

@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
    for i in range(0,100):
        # Calculate new level at time i
        new_lvl = start_lvl + inflow[i] + actions       # Dim [N x M]

        # For each new_level element, find closest discretized level
        diff    = (disc_lvl-new_lvl[:,:,np.newaxis])    # Dim [N x M x O]
        idx_lvl = abs(diff).argmin(axis=2)              # Dim [N x M]

        return True

# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)

Почему не работает приведенный выше код? Это делает при вынимании @jit(nopython=True). Есть ли обходной путь, чтобы следующий расчет работал с Numba?

Я пробовал варианты с numpy repeat & expand_dims, а также без явного определения типов ввода функции jit.

Ответы [ 2 ]

2 голосов
/ 23 апреля 2020

Есть несколько вещей, которые необходимо изменить, чтобы он работал:

  1. Добавление измерения с помощью arr[:, :, None]: для Numba это выглядит как getitem, поэтому предпочитайте использовать reshape
  2. Используйте np.abs вместо встроенного abs
  3. Аргумент argmin с axis ключевым словом не реализован . Предпочитаю использовать циклы, которые Numba предназначена для оптимизации.

Со всем этим исправлено, вы можете запустить функцию jlit:

from numba import jit
import numpy as np

inflow = np.arange(1,0,-0.01)  # Dim [T]
actions = np.arange(0,1,0.05)  # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]
disc_lvl = np.arange(0,1000)  # Dim [O]

@jit(nopython=True)
def my_func(disc_lvl, actions, start_lvl, inflow):
    for i in range(0,100):
        # Calculate new level at time i
        new_lvl = start_lvl + inflow[i] + actions  # Dim [N x M]

        # For each new_level element, find closest discretized level
        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]

        idx_lvl = np.empty(new_lvl.shape)
        for i in range(diff.shape[0]):
            for j in range(diff.shape[1]):
                idx_lvl[i, j] = diff[i, j, :].argmin()

        return True

# function works fine without numba
success = my_func(disc_lvl, actions, start_lvl, inflow)
0 голосов
/ 24 апреля 2020

Ниже приведен исправленный код моего первого поста, который вы можете выполнить с режимом jittle библиотеки numba и без него (удалив строку, начинающуюся с @jit). В этом примере я наблюдал увеличение скорости в 2 раза.

from numba import jit
import numpy as np
import datetime as dt

inflow = np.arange(1,0,-0.01)                       # Dim [T]
nbTime = np.shape(inflow)[0]
actions = np.arange(0,1,0.01)                       # Dim [M]
start_lvl = np.random.rand(500).reshape(-1,1)*49    # Dim [Nx1]
disc_lvl = np.arange(0,1000)                        # Dim [O]

@jit(nopython=True)
def my_func(nbTime, disc_lvl, actions, start_lvl, inflow):
    # Initialize result 
    res = np.empty((nbTime,np.shape(start_lvl)[0],np.shape(actions)[0]))

    for t in range(0,nbTime):
        # Calculate new level at time t
        new_lvl = start_lvl + inflow[t] + actions  # Dim [N x M]      
        print(t)

        # For each new_level element, find closest discretized level
        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)
        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]

        idx_lvl = np.empty(new_lvl.shape)
        for i in range(diff.shape[0]):
            for j in range(diff.shape[1]):
                idx_lvl[i, j] = diff[i, j, :].argmin()

        res[t,:,:] = idx_lvl

    return res

# Call function and print running time
start_time = dt.datetime.now()
result = my_func(nbTime, disc_lvl, actions, start_lvl, inflow)
print('Execution time :',(dt.datetime.now() - start_time))
...