Проблема с Numba jit: «Ошибка ввода» и «Все шаблоны отклонены с / без литералов» - PullRequest
2 голосов
/ 01 июля 2019

Я реализую программу, которая решает дифференциальное уравнение в Python 3.7.3, и есть одна функция, которую я просто не могу скомпилировать с Numba.Самая последняя его версия:

import numpy as np
from numba import jit, uint16, complex128, prange

# Here is the setup of the program, as well as variable initialization

@jit((complex128[:, :, :], uint16, complex128[:, :], complex128[:, :], complex128[:, :]), nopython=True)
def upd_x(rhs: np.ndarray, m: int, s: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    x = np.zeros((3, m, m//2+1))
    x[2] = s*(1-a*(rhs[0]+rhs[1]))
    for i in range(2):
        x[i] = a*(rhs[i]+b*x[2])
    return x

Что нужно сделать, это взять «правую часть» (rhs) уравнения и обновить x (x имеет3 компонента, которые являются реальными полями, и код «обновляет» его в пространстве Фурье, поэтому последняя ось равна m//2+1 вместо m) с помощью метода дополнения Шура.Когда я запустил код, я получил следующее сообщение:

Traceback (most recent call last):
  File "C:/Users/Username/Desktop/Program/Program.py", line 95, in <module>
    @jit((complex128[:, :, :], uint16, complex128[:, :], complex128[:, :], complex128[:, :]), nopython=True)
  File "C:\Program Files\Python37\lib\site-packages\numba\decorators.py", line 186, in wrapper
    disp.compile(sig)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "C:\Program Files\Python37\lib\site-packages\numba\dispatcher.py", line 659, in compile
    cres = self._compiler.compile(args, return_type)
  File "C:\Program Files\Python37\lib\site-packages\numba\dispatcher.py", line 83, in compile
    pipeline_class=self.pipeline_class)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 955, in compile_extra
    return pipeline.compile_extra(func)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 377, in compile_extra
    return self._compile_bytecode()
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 886, in _compile_bytecode
    return self._compile_core()
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 873, in _compile_core
    res = pm.run(self.status)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 254, in run
    raise patched_exception
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 245, in run
    stage()
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 501, in stage_nopython_frontend
    self.locals)
  File "C:\Program Files\Python37\lib\site-packages\numba\compiler.py", line 1105, in type_inference_stage
    infer.propagate()
  File "C:\Program Files\Python37\lib\site-packages\numba\typeinfer.py", line 915, in propagate
    raise errors[0]
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 3d, C), Literal[int](2), array(complex128, 2d, C))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at C:/Users/User/Desktop/Program/Program.py (98)

File "Programa.py", line 98:
def upd_x(rhs, m, s, a, b):
    <source elided>
    x = np.zeros((3, m, m//2+1))
    x[2] = s*(1-a*(rhs[0]+rhs[1]))
    ^

Я не понимаю, почему в сообщении об ошибке указывается, что тип переменной не поддерживается, и я также не знаю, что мне нужноисправлять.Версии, которые я использую: numba == 0.44.1, numpy == 1.16.1.

Большое спасибо.

1 Ответ

1 голос
/ 01 июля 2019

Похоже, что Numba не смог определить тип вывода x, поэтому я добавил dtype к x. Затем вы сталкиваетесь со смешиванием np.int64 и uint16 с размерами аргументов до np.zeros, поскольку 3 интерпретируется как i64. Таким образом, будет скомпилировано следующее:

import numpy as np
from numba import jit, uint16, complex128, prange

# Here is the setup of the program, as well as variable initialization

@jit(complex128[:,:,:](complex128[:, :, :], uint16, complex128[:, :], complex128[:, :], complex128[:, :]), nopython=True)
def upd_x(rhs: np.ndarray, m: int, s: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    mx = np.int64(m)
    x = np.zeros((3, mx, mx//2+1), dtype=np.complex128)
    x[2] = s*(1-a*(rhs[0]+rhs[1]))
    for i in range(2):
        x[i] = a*(rhs[i]+b*x[2])
    return x

Кроме того, обратите внимание, что я добавил тип возвращаемого значения для подписи, передаваемой @jit, хотя я считаю, что в этом нет необходимости.

И поэтому я использую входные данные:

m = 4
x = np.zeros((3, m, m//2+1), dtype=np.complex128) + 2 + 2j
y = np.zeros((m, m//2 + 1 ), dtype=np.complex128) + 1 + 1j

upd_x(x, np.uint16(m), y, y, y)

и это возвращает что-то разумное, я думаю.

...