Использование np.min с вводом списка в функции numba - PullRequest
0 голосов
/ 03 марта 2019

В чем проблема с использованием np.min здесь?Почему numba не нравится использование списка в этой функции, есть ли другой способ заставить np.min работать?

from numba import njit
import numpy as np

@njit
def availarray(length):
    out=np.ones(14)
    if length>0:
        out[0:np.min([int(length),14])]=0
    return out

availarray(3)

Функция отлично работает с min, но np.min должна быть быстрее...

Ответы [ 2 ]

0 голосов
/ 03 марта 2019

Проблема в том, что для numba-версии np.min требуется array в качестве ввода.

from numba import njit
import numpy as np

@njit
def test_numba_version_of_numpy_min(inp):
    return np.min(inp)

>>> test_numba_version_of_numpy_min(np.array([1, 2]))  # works
1

>>> test_numba_version_of_numpy_min([1, 2]) # doesn't work
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function amin at 0x000001B5DBDEE598>) with argument(s) of type(s): (reflected list(int64))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.

Лучшим решением было бы просто использовать numba-версию Pythons min:

from numba import njit
import numpy as np

@njit
def availarray(length):
    out = np.ones(14)
    if length > 0:
        out[0:min(length, 14)] = 0
    return out

Поскольку np.min и min на самом деле являются версиями этих функций Numba (по крайней мере, в njit ted функциях), min также должна быть намного быстрее в этом случае.Однако это вряд ли будет заметно, потому что выделение массива и установка некоторых элементов в ноль будут основными вкладчиками времени выполнения здесь.

Обратите внимание, что вам даже не нужен вызов min здесь - потому чтонарезка неявно останавливается в конце массива, даже если используется больший индекс остановки:

from numba import njit
import numpy as np

@njit
def availarray(length):
    out = np.ones(14)
    if length > 0:
        out[0:length] = 0
    return out
0 голосов
/ 03 марта 2019

Чтобы ваш код работал с numba, вам придется применить np.min к массиву NumPy, что означает, что вам придется преобразовать ваш список [int(length),14] в массив NumPy следующим образом

from numba import njit
import numpy as np

@njit
def availarray(length):
    out=np.ones(14)
    if length>0:
        out[0:np.min(np.array([int(length),14]))]=0   
    return out

availarray(3)
# array([0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
...