numba-безопасная версия itertools.combinsk? - PullRequest
4 голосов
/ 17 апреля 2020

У меня есть некоторый код, который проходит через большой набор itertools.combinations, который сейчас является узким местом в производительности. Я пытаюсь обратиться к numba @jit(nopython=True), чтобы ускорить его, но у меня возникают некоторые проблемы.

Во-первых, кажется, что numba сама не может обработать itertools.combinations, согласно этому небольшому примеру:

import itertools
import numpy as np
from numba import jit

arr = [1, 2, 3]
c = 2

@jit(nopython=True)
def using_it(arr, c):
    return itertools.combinations(arr, c)

for i in using_it(arr, c):
    print(i)

ошибка броска: numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend) Unknown attribute 'combinations' of type Module(<module 'itertools' (built-in)>)

После некоторого поиска в Google, Я обнаружил эту проблему с github , в которой спрашивающий предложил эту numba-безопасную функцию для вычисления перестановок:

@jit(nopython=True)
def permutations(A, k):
    r = [[i for i in range(0)]]
    for i in range(k):
        r = [[a] + b for a in A for b in r if (a in b)==False]
    return r

Используя это, я могу легко отфильтровать по комбинациям:

@jit(nopython=True)
def combinations(A, k):
    return [item for item in permutations(A, k) if sorted(item) == item]

Теперь я могу запустить эту функцию combinations без ошибок и получить правильный результат. Однако теперь с @jit(nopython=True) это значительно медленнее, чем без него. Выполнение этого теста синхронизации:

A = list(range(20))  # numba throws 'cannot determine numba type of range' w/o list
k = 2
start = pd.Timestamp.utcnow()
print(combinations(A, k))
print(f"took {pd.Timestamp.utcnow() - start}")

включается через 2,6 секунды с помощью numba @jit(nopython=True) декораторов и менее 1/000 секунды с ними закомментировано. Так что для меня это тоже не реально.

1 Ответ

0 голосов
/ 23 апреля 2020

В этом случае нечего выиграть с Numba, так как itertools.combinations записано в C.

Если вы хотите сравнить его, вот Numba / Python реализация того, что делает itertools.combinatiions:

@jit(nopython=True)
def using_numba(pool, r):
    n = len(pool)
    indices = list(range(r))
    empty = not(n and (0 < r <= n))

    if not empty:
        result = [pool[i] for i in indices]
        yield result

    while not empty:
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            empty = True
        else:
            indices[i] += 1
            for j in range(i+1, r):
                indices[j] = indices[j-1] + 1

            result = [pool[i] for i in indices]
            yield result

На моей машине это примерно в 15 раз медленнее, чем itertools.combinations. Получение перестановок и фильтрация комбинаций, безусловно, будет еще медленнее.

...