numba.jit не может скомпилировать np.roll - PullRequest
0 голосов
/ 03 апреля 2020

Я пытаюсь скомпилировать функцию "foo", используя jit

import numpy as np
from numba import jit

dy = 5
@jit
def foo(grid):
    return np.sum([np.roll(np.roll(grid, y, axis = 1), x, axis = 0)
                   for x in (-1, 0, 1) for y in (-1, 0, 1) if x or y], axis=0)


ex_grid = np.random.rand(5,5)>0.5
result = foo(ex_grid)

И я получаю следующую ошибку:

Compilation is falling back to object mode WITH looplifting enabled because Function "foo" failed type inference due to: Invalid use of Function(<function roll at 0x00000161E45C7D90>) with argument(s) of type(s): (array(bool, 2d, C), Literal[int](5), axis=Literal[int](1))
 * parameterized
In definition 0:
    TypeError: np_roll() got an unexpected keyword argument 'axis'

Функция работает, но компиляция не удалась.

Как я могу исправить эту ошибку, совместим ли np.roll с numba, и если нет, есть альтернатива?

1 Ответ

1 голос
/ 03 апреля 2020

Если вы проверите docs , вы увидите, что для np.roll поддерживаются только два первых аргумента, следовательно, он будет выполнять прокрутку только для сплющенного массива (поскольку вы не можете указать ось) .

numpy .roll () (только 2 первых аргумента; сдвиг второго аргумента должен быть целым числом)

Обратите внимание, что в действительности здесь не имеет смысла использовать numba, поскольку вы выполняете одну векторизованную операцию, которая уже будет выполняться очень быстро. Numba будет иметь смысл только в том случае, если вам потребуется l oop над массивом, чтобы применить некоторые логики c.

Таким образом, единственный возможный способ roll строк вашего массива здесь, используя numba, будет l oop над ними:

@njit
def foo(a, dy):
    out = np.empty(a.shape, np.int32)
    for i in range(a.shape[0]):
        out[i] = np.roll(a[i], dy)
    return out

np.allclose(foo(ex_grid, 3).astype(bool), np.roll(ex_grid, 3, axis=1))
# True

Хотя, как уже упоминалось, это будет много медленнее, чем просто с помощью np.roll установка axis=1, поскольку это уже векторизация и все циклы выполняются на уровне C:

ex_grid = np.random.rand(5000,5000)>0.5

%timeit foo(ex_grid, 3)
# 111 ms ± 820 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit np.roll(ex_grid, 1, axis=1)
# 13.8 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
...