Как скомпилировать numba jit'ed функцию с переменным типом ввода? - PullRequest
2 голосов
/ 23 апреля 2019

Скажем, у меня есть функция, которая может принимать как int, так и None тип в качестве входного аргумента

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}


@nb.jit("f8(i8)", **jitkw)
def get_random(seed=None):
    np.random.seed(None)
    out = np.random.normal()
    return out

Я хочу, чтобы функция просто возвращала нормально распределенное случайное число.Если я хочу воспроизводимых результатов, seed должен быть int.

get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327

Если я хочу случайные числа, seed следует оставить как None.Однако, если я не передам аргумент (так по умолчанию seed = None) или явно передам seed=None, тогда numba вызовет TypeError

get_random()
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
get_random(None)
>>> TypeError: No matching definition for argument type(s) omitted(default=None)

Как я могу написать функцию, все еще объявляяподпись и использование режима nopython для такого сценария?

Моя версия numba - 0.43.1

1 Ответ

2 голосов
/ 23 апреля 2019

Первая проблема заключается в том, что numba в режиме nopython принимает только (начиная с версии 0.43.1) np.random.seed: только с целочисленным аргументом .

Так что, к сожалению, вы не можете перейти на None.


Вторая проблема заключается в том, что (насколько я знаю) нет «единой» сигнатуры, которая говорит numba, как обращаться с пропущенными значениями, однако вы можете использовать две сигнатуры (да, это очень многословно):

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}

@nb.jit(
    [nb.types.float64(nb.types.misc.Omitted(None)), 
     nb.types.float64(nb.types.int64)], 
    **jitkw)
def get_random(seed=None):
    return np.random.normal()

Просто краткое объяснение двух частей подписи:

  • nb.types.float64(nb.types.misc.Omitted(None)) указывает numba использовать None в качестве типа по умолчанию, если аргумент опущен
  • и nb.types.float64(nb.types.int64) - это подпись, которая ожидает целое число.

Лично я бы не стал указывать подпись и просто позволил numba выяснить это. Явные подписи редко стоят в numba, и чаще всего они не приводят к более медленному и менее гибкому коду.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...