Как использовать Numba вместе с functools.reduce () - PullRequest
2 голосов
/ 11 марта 2020

У меня есть следующий код, где я пытаюсь провести параллель с l oop, используя numba, functools.reduce() и mul:

import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import jit, prange

lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
gen = np.array([list(a) for a in product(flat, repeat=n)])

@jit(nopython=True, parallel=True)
def mtp(gen):
    results = np.empty(gen.shape[0])
    for i in prange(gen.shape[0]):
        results[i] = reduce(mul, gen[i], initializer=None)
    return results
mtp(gen)

Но это дает мне ошибку:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-503-cd6ef880fd4a> in <module>
     10         results[i] = reduce(mul, gen[i], initializer=None)
     11     return results
---> 12 mtp(gen)

~\Anaconda3\lib\site-packages\numba\dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

~\Anaconda3\lib\site-packages\numba\dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

~\Anaconda3\lib\site-packages\numba\six.py in reraise(tp, value, tb)
    666             value = tp()
    667         if value.__traceback__ is not tb:
--> 668             raise value.with_traceback(tb)
    669         raise value
    670 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function reduce>) with argument(s) of type(s): (Function(<built-in function mul>), array(int32, 1d, C), initializer=none)
 * parameterized
In definition 0:
    AssertionError: 
    raised from C:\Users\HP\Anaconda3\lib\site-packages\numba\parfor.py:4138
In definition 1:
    AssertionError: 
    raised from C:\Users\HP\Anaconda3\lib\site-packages\numba\parfor.py:4138
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function reduce>)
[2] During: typing of call at <ipython-input-503-cd6ef880fd4a> (10)


File "<ipython-input-503-cd6ef880fd4a>", line 10:
def mtp(gen):
    <source elided>
    for i in prange(gen.shape[0]):
        results[i] = reduce(mul, gen[i], initializer=None)
        ^

Я не уверен, где я ошибся. Кто-нибудь может указать мне правильное направление? Большое спасибо.

1 Ответ

2 голосов
/ 13 марта 2020

Вы можете использовать np.prod внутри объединенной функции numba:

n = 3
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
flat = np.ravel(arr).tolist()
gen = [list(a) for a in product(flat, repeat=n)]

@jit(nopython=True, parallel=True)
def mtp(gen):
    results = np.empty(len(gen))
    for i in prange(len(gen)):
        results[i] = np.prod(gen[i])
    return results

В качестве альтернативы, вы можете использовать уменьшение, как показано ниже (спасибо @stuartarchibald за указание на это), хотя распараллеливание не будет работать ниже (по крайней мере, для numba 0.48):

import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import njit, prange

lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
gen = np.array([list(a) for a in product(flat, repeat=n)])

@njit
def mul_wrapper(x, y):
    return mul(x, y)

@njit
def mtp(gen):
    results = np.empty(gen.shape[0])
    for i in prange(gen.shape[0]):
        results[i] = reduce(mul_wrapper, gen[i], None)
    return results

print(mtp(gen))

Или потому, что внутри Numba есть немного волшебства c, которое обнаруживает замыкания, которые будут выходить из функций и компилировать их. (опять же благодаря @stuartarchibald), вы можете сделать это ниже:

@njit
def mtp(gen):
    results = np.empty(gen.shape[0])
    def op(x, y):
        return mul(x, y)
    for i in prange(gen.shape[0]):
        results[i] = reduce(op, gen[i], None)
    return results

Но опять же, параллель здесь не работает с нумба 0,48.

Примечание рекомендуемый подход от члена основной команды разработчиков - это первое решение, использующее np.prod. Он может использоваться с флагом параллели и имеет более простую реализацию.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...