Понимание Numba TypingError с помощью jit nopython - PullRequest
2 голосов
/ 24 апреля 2020

У меня проблемы с решением (возможно, основной c) ошибки Numba с использованием @jit(nopython=True). Это сводится к минимальному приведенному ниже примеру, который дает TypingError (полные журналы ниже). Если уместно, я использую Python 3.6.10 и Numba v0.49.0.

Ошибка возникает в строке d, создающей массив numpy (если я удаляю d и возвращаю c, работает нормально). Как я могу решить эту проблему?

from numba import jit
import numpy as np

n = 5
foo = np.random.rand(n,n)

@jit(nopython=True)
def bar(x):
    a = np.array([0,3,2])
    b = np.array([1,2,3])
    c = [x[i,j] for i,j in zip(a,b)]
    # print(c) # Un-commenting this line solves the issue‽ (per @Ethan's comment)
    d = np.array(c)
    return d

baz = bar(foo)

Полная ошибка:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-13-950d2be33d72> in <module>
     14     return d
     15 
---> 16 baz = bar(foo)
     17 print(baz)

~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

~/miniconda3/envs/py3k/lib/python3.6/site-packages/numba/core/utils.py in reraise(tp, value, tb)
     77         value = tp()
     78     if value.__traceback__ is not tb:
---> 79         raise value.with_traceback(tb)
     80     raise value
     81 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<intrinsic range_iter_len>) with argument(s) of type(s): (zip(iter(array(int64, 1d, C)), iter(array(int64, 1d, C))))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<intrinsic range_iter_len>)
[2] During: typing of call at <ipython-input-13-950d2be33d72> (9)


File "<ipython-input-13-950d2be33d72>", line 9:
def bar(x):
    a = np.array([0,3,2])
    ^

Обновление: Использование следующей функции вместо этого аналогичным образом завершается ошибкой (хотя print(c) трюк не помогает в этом случае):

@jit(nopython=True)
def bar(x):
    a = [0,3,2]
    b = [1,2,3]
    c = x[a, b]
    d = np.array(c)
    return d

1 Ответ

1 голос
/ 26 апреля 2020

Проблема с первой версией функции и тот факт, что добавление print(c) разрешает ее, для меня загадка. Предполагается, что Numba реализует zip (и, очевидно, в данном конкретном случае это может происходить при вызове строки print(c)), поэтому это похоже на ошибку.

Проблема со второй версией функция меньше загадки. Согласно текущей документации Numba :

Массивы поддерживают нормальную итерацию. Поддерживается полное индексирование и нарезка c. Также поддерживается подмножество расширенного индексирования: разрешен только один расширенный индекс, и он должен быть одномерным массивом (его также можно комбинировать с произвольным числом базовых c индексов).

Поскольку вы пытаетесь использовать два расширенных индекса, a и b, в строке c = x[a, b], код не поддерживается Numba. В самом деле, это то, что говорит многословное сообщение об ошибке Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), tuple(array(int64, 1d, C) x 2)).

Если бы вместо этого мы написали c=x[a,2], то код работал бы в соответствии с обещанием Numba разрешить один расширенный индекс.

В целом я обнаружил, что самый безопасный способ использовать Numba это писать в зацикленном стиле без более продвинутых функций NumPy. Это немного прискорбно - поскольку это почти как если бы нам нужно было писать на диалекте C, а не на Python - но с положительной стороны это все же гораздо удобнее, чем на самом деле писать C.

In в этом ключе хорошо работает следующий код:

@jit(nopython=True)
def bar(x):
    a = np.array([0,3,2])
    b = np.array([1,2,3])
    c = np.empty(len(a))
    for i in range(len(a)):
        c[i] = x[a[i], b[i]]
    return c
...