Я работаю с Numba (0,44) и Numpy в режиме nopython
.В настоящее время Numba не поддерживает функции агрегирования Numpy по произвольной оси, она поддерживает только вычисление этих агрегатов по всему массиву.Учитывая ситуацию, я решил взять трещину и создать несколько прокладок.
В коде:
np.min(array) # This works with Numba 0.44
np.min(array, axis = 0) # This does not work with Numba 0.44 (no axis argument allowed)
Вот пример прокладки, предназначенной для воспроизведения np.min(array)
:
import numpy as np
import numba
@numba.jit(nopython = True)
def npmin (X, axis = -1):
"""
Shim for broadcastable np.min().
Allows np.min(array), np.min(array, axis = 0), and np.min(array, axis = 1)
Note that the argument axis = -1 computes on the entire array.
"""
if axis == 0:
_min = np.sort(X.transpose())[:,0]
elif axis == 1:
_min = np.sort(X)[:,0]
else:
_min = np.sort(np.sort(X)[:,0])[0]
return _min
Без Numba, прокладка работает как положено иповторяет поведение np.min()
до двумерного массива.Обратите внимание, что я использую axis = -1
в качестве средства разрешения суммирования всего массива - поведение, аналогичное вызову np.min(array)
без аргумента axis
.
К сожалению, как только я добавляю Numba в микс, я получаю ошибку.Вот след:
Traceback (most recent call last):
File "shims.py", line 81, in <module>
_min = npmin(a)
File "/usr/local/lib/python3.7/site-packages/numba/dispatcher.py", line 348, in _compile_for_args
error_rewrite(e, 'typing')
File "/usr/local/lib/python3.7/site-packages/numba/dispatcher.py", line 315, in error_rewrite
reraise(type(e), e, None)
File "/usr/local/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function sort at 0x10abd5ea0>) with argument(s) of type(s): (array(int64, 2d, F))
* parameterized
In definition 0:
All templates rejected
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function sort at 0x10abd5ea0>)
[2] During: typing of call at shims.py (27)
File "shims.py", line 27:
def npmin (X, axis = -1):
<source elided>
if axis == 0:
_min = np.sort(X.transpose())[:,0]
^
This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.
To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html
For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile
If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new
Я проверил, что все функции, которые я использую, и их соответствующие аргументы поддерживаются в Numba 0.44.Конечно, трассировка стека говорит проблема связана с моим вызовом np.sort(array)
, но я подозреваю, что это может быть проблема с типизацией, потому что функция может возвращать либо скаляр (без аргумента оси), либо 2D-массив(с аргументом оси).
Тем не менее, у меня есть несколько вопросов:
- Есть ли проблема с моей реализацией;Может ли кто-нибудь точно определить неподдерживаемую функцию, которую я использую, как предполагает трассировка стека?
- Или, скорее, это ошибка в Numba?
- В общем, эти типыпрокладки в настоящее время возможно с Numba (0,44)?