Вы можете использовать np.prod внутри объединенной функции numba:
n = 3
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
flat = np.ravel(arr).tolist()
gen = [list(a) for a in product(flat, repeat=n)]
@jit(nopython=True, parallel=True)
def mtp(gen):
results = np.empty(len(gen))
for i in prange(len(gen)):
results[i] = np.prod(gen[i])
return results
В качестве альтернативы, вы можете использовать уменьшение, как показано ниже (спасибо @stuartarchibald за указание на это), хотя распараллеливание не будет работать ниже (по крайней мере, для numba 0.48):
import numpy as np
from itertools import product
from functools import reduce
from operator import mul
from numba import njit, prange
lst = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
arr = np.array(lst)
n = 3
flat = np.ravel(arr).tolist()
gen = np.array([list(a) for a in product(flat, repeat=n)])
@njit
def mul_wrapper(x, y):
return mul(x, y)
@njit
def mtp(gen):
results = np.empty(gen.shape[0])
for i in prange(gen.shape[0]):
results[i] = reduce(mul_wrapper, gen[i], None)
return results
print(mtp(gen))
Или потому, что внутри Numba есть немного волшебства c, которое обнаруживает замыкания, которые будут выходить из функций и компилировать их. (опять же благодаря @stuartarchibald), вы можете сделать это ниже:
@njit
def mtp(gen):
results = np.empty(gen.shape[0])
def op(x, y):
return mul(x, y)
for i in prange(gen.shape[0]):
results[i] = reduce(op, gen[i], None)
return results
Но опять же, параллель здесь не работает с нумба 0,48.
Примечание рекомендуемый подход от члена основной команды разработчиков - это первое решение, использующее np.prod
. Он может использоваться с флагом параллели и имеет более простую реализацию.