Ошибка в рекурсивной функции с Numba в режиме nopython - PullRequest
1 голос
/ 08 апреля 2019

Я хочу запустить рекурсивную функцию в Numba, используя режим nopython. До сих пор я получаю только ошибки. Это очень простой код, пользователь дает кортеж с менее чем пятью элементами, а затем функция создает другой кортеж с новым значением, добавленным к кортежу (в данном случае, число 3). Это повторяется до тех пор, пока длина финального кортежа не станет 5. По какой-то причине это не работает, не знаю почему.

@njit
def tup(a):
    if len(a) == 5:
        return a
    else:
        b = a + (3,)
        b = tup(b)
        return b

Например, если a = (0,1), я ожидаю, что конечный результат будет кортеж (0,1,3,3,3).

РЕДАКТИРОВАТЬ: Я использую Numba 0.41.0, и ошибка, которую я получаю, умирает ядро: «Ядро, кажется, умерло. Он перезапустится автоматически. '

Ответы [ 2 ]

1 голос
/ 08 апреля 2019

Есть несколько причин, почему вы не должны этого делать:

  • Как правило, это такой подход, который, скорее всего, будет быстрее в чистом Python, чем в украшенной нумбой функции.
  • Итерация будет проще и, вероятно, быстрее, однако имейте в виду, что объединение кортежей обычно является O(n) операцией, даже в numba. Таким образом, общая производительность функции будет O(n**2). Это можно улучшить, используя структуру данных, которая поддерживает добавление O(1), или структуру данных, которая поддерживает предварительное распределение размера. Или просто не используя «зацикленный» или «рекурсивный» подход.
  • Вы пробовали, что произойдет, если вы пропустите декоратор njit и передадите кортеж, содержащий 6 элементов? (подсказка: он достигнет предела рекурсии, потому что он никогда не выполнит конечное условие рекурсии).

На момент написания 0.43.1 Numba поддерживает только простые рекурсии, когда тип аргументов между рекурсиями не меняется. В вашем случае тип меняется, вы передаете tuple(int64 x 2), но рекурсивный вызов пытается передать tuple(int64 x 3), который является другим типом. Странно, но на моем компьютере это StackOverflow, что похоже на ошибку в numba.

Я бы предложил использовать это (без нумбы, без рекурсии):

def tup(a):
    if len(a) < 5:
        a += (3, ) * (5 - len(a))
    return a

, который также возвращает ожидаемый результат:

>>> tup((1,))
(1, 3, 3, 3, 3)
>>> tup((1, 2))
(1, 2, 3, 3, 3)
1 голос
/ 08 апреля 2019

Согласно этот список предложений в текущих выпусках:

Поддержка рекурсии в numba в настоящее время ограничена саморекурсией с явной аннотацией типа для функции.Это ограничение происходит из-за невозможности определить тип возврата рекурсивного вызова.

Итак, вместо этого попробуйте:

from numba import jit

@jit()
def tup(a:tuple) -> tuple:
    if len(a) == 5:
        return a

    return tup(a + (3,))

print(tup((0, 1)))

Чтобы проверить, работает ли это лучше для вас.

...