Забавное поведение с numba-guvectorized функциями с использованием argmax () - PullRequest
0 голосов
/ 23 ноября 2018

Рассмотрим следующий скрипт:

from numba import guvectorize, u1, i8
import numpy as np

@guvectorize([(u1[:],i8)], '(n)->()')
def f(x, res):
    res = x.argmax()

x = np.array([1,2,3],dtype=np.uint8)
print(f(x))
print(x.argmax())
print(f(x))

При запуске я получаю следующее:

4382569440205035030
2
2

Почему это происходит?Есть ли способ сделать это правильно?

1 Ответ

0 голосов
/ 27 ноября 2018

Python не имеет ссылок, поэтому res = ... на самом деле не присваивает выходной параметр, а вместо этого связывает имя res.Я полагаю, что res указывает на неинициализированную память, поэтому ваш первый запуск дает, казалось бы, случайное значение.

Numba работает с этим, используя синтаксис среза ([:]), который действительно мутирует res; вам также нужно объявитьтип в виде массива.Рабочая функция:

@guvectorize([(u1[:], i8[:])], '(n)->()')
def f(x, res):
    res[:] = x.argmax()
...