Numpy агрегатная функция shims, типирование и np.sort () в Numba - PullRequest
1 голос
/ 30 июня 2019

Я работаю с Numba (0,44) и Numpy в режиме nopython.В настоящее время Numba не поддерживает функции агрегирования Numpy по произвольной оси, она поддерживает только вычисление этих агрегатов по всему массиву.Учитывая ситуацию, я решил взять трещину и создать несколько прокладок.

В коде:

np.min(array) # This works with Numba 0.44
np.min(array, axis = 0) # This does not work with Numba 0.44 (no axis argument allowed)

Вот пример прокладки, предназначенной для воспроизведения np.min(array):

import numpy as np
import numba

@numba.jit(nopython = True)
def npmin (X, axis = -1):
    """
    Shim for broadcastable np.min(). 
    Allows np.min(array), np.min(array, axis = 0), and np.min(array, axis = 1)
    Note that the argument axis = -1 computes on the entire array.
    """
    if axis == 0:
        _min = np.sort(X.transpose())[:,0]
    elif axis == 1:
        _min = np.sort(X)[:,0]
    else:
        _min = np.sort(np.sort(X)[:,0])[0]
    return _min

Без Numba, прокладка работает как положено иповторяет поведение np.min() до двумерного массива.Обратите внимание, что я использую axis = -1 в качестве средства разрешения суммирования всего массива - поведение, аналогичное вызову np.min(array) без аргумента axis.

К сожалению, как только я добавляю Numba в микс, я получаю ошибку.Вот след:

Traceback (most recent call last):
  File "shims.py", line 81, in <module>
    _min = npmin(a)
  File "/usr/local/lib/python3.7/site-packages/numba/dispatcher.py", line 348, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/usr/local/lib/python3.7/site-packages/numba/dispatcher.py", line 315, in error_rewrite
    reraise(type(e), e, None)
  File "/usr/local/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
    raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function sort at 0x10abd5ea0>) with argument(s) of type(s): (array(int64, 2d, F))
 * parameterized
In definition 0:
    All templates rejected
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function sort at 0x10abd5ea0>)
[2] During: typing of call at shims.py (27)


File "shims.py", line 27:
def npmin (X, axis = -1):
    <source elided>
    if axis == 0:
        _min = np.sort(X.transpose())[:,0]
        ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile

If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new

Я проверил, что все функции, которые я использую, и их соответствующие аргументы поддерживаются в Numba 0.44.Конечно, трассировка стека говорит проблема связана с моим вызовом np.sort(array), но я подозреваю, что это может быть проблема с типизацией, потому что функция может возвращать либо скаляр (без аргумента оси), либо 2D-массив(с аргументом оси).

Тем не менее, у меня есть несколько вопросов:

  • Есть ли проблема с моей реализацией;Может ли кто-нибудь точно определить неподдерживаемую функцию, которую я использую, как предполагает трассировка стека?
  • Или, скорее, это ошибка в Numba?
  • В общем, эти типыпрокладки в настоящее время возможно с Numba (0,44)?

1 Ответ

1 голос
/ 03 июля 2019

Вот альтернативная прокладка для 2d массивов:

@numba.jit(nopython=True)
def npmin2(X, axis=0):
    if axis == 0:
        _min = np.empty(X.shape[1])
        for i in range(X.shape[1]):
            _min[i] = np.min(X[:,i])
    elif axis == 1:
        _min = np.empty(X.shape[0])
        for i in range(X.shape[0]):
            _min[i] = np.min(X[i,:])

    return _min

хотя вам придется найти обходной путь для случая axis=-1, потому что он вернет скаляр, а другие аргументы вернут массивы, и Numba не сможет «объединить» возвращаемый тип во что-то последовательны.

Производительность, по крайней мере на моей машине, кажется примерно сопоставимой с простым вызовом эквивалентного np.min, иногда np.min быстрее, а иногда npmin2 выигрыш, в зависимости от размера входного массива и оси .

...