Изменения Numba изменяются в результате добавления определенного типа 0 к локальной переменной - PullRequest
5 голосов
/ 17 февраля 2020

Рассмотрим следующую функцию для вычисления количества шагов для заданного входа для задачи 3 n + 1:

def num_steps(b, steps):
    e = b
    d = 0
    while True:
        if e == 1:
            d += steps[e]
            return d
        if e % 2 == 0:
            e //= 2
        else:
            e = 3*e + 1
        d += 1

Здесь steps существует для учета результата, но ради этого вопроса, мы просто отметим, что, пока steps[1] == 0, это не должно иметь никакого эффекта, поскольку в этом случае эффект d += steps[e] заключается в добавлении 0 к d. Действительно, следующий пример дает ожидаемый результат:

import numpy as np

steps = np.array([0, 0, 0, 0])
print(num_steps(3, steps))  # Prints 7

Если, однако, мы JIT скомпилируем метод, используя numba.jit (или njit), мы больше не получим правильный результат:

import numpy as np
from numba import jit

steps = np.array([0, 0, 0, 0])
print(jit(num_steps)(3, steps))  # Prints 0

Если мы удалим кажущееся избыточным d += steps[e] до компиляции метода, мы получим правильный результат. Мы могли бы даже вставить print(steps[e]) до d += steps[e] и увидеть, что значение равно 0. Я также могу переместить d += 1 в верхнюю часть l oop (и вместо этого инициализировать d = -1), чтобы получить что-то это также работает в случае с Numba.

Это происходит с Numba 0.48.0 (llvmlite 0.31.0) на Python 3.8 (самые последние версии доступны через стандартный канал conda).

1 Ответ

2 голосов
/ 17 февраля 2020

Для меня это похоже на ошибку, что-то с приращением на месте с steps[e]. Если вы установите parallel=True, то там, где Numba падает. Вы можете создать проблему в репозитории Numba github, возможно, разработчики могут объяснить это.

Если я переписываю функцию, чтобы избежать этого окончательного увеличения на месте, она работает для меня:

@numba.njit
def numb_steps(b, steps):

    e = b    
    d = 0

    while True:

        if e == 1:
            return d + steps[e]

        if e % 2 == 0:
            e //= 2
        else:
            e = 3*e + 1

        d += 1

С:

python                    3.7.6
numba                     0.47.0
...