Как получать значения из функций, не возвращая их и не записывая в глобальные переменные (для numba.cuda)? - PullRequest
0 голосов
/ 13 июля 2020

Я пытаюсь запустить этот простой код на графическом процессоре CUDA. Модуль, который я использую для этого, - numba.cuda:

import numba
from numba import cuda

@numba.cuda.jit
def function_4(j, k):
    l = j + k
    return l  # returning a value is what the TypingError in the output below points at

l = function_4(1, 2)  # compilation is triggered by this first call (see traceback)
print(l)

Вывод:

Traceback (most recent call last):
  File "/home/amu/Desktop/RL_framework/help_functions/test2.py", line 9, in <module>
    l = function_4(1, 2)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 758, in __call__
    kernel = self.specialize(*args)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 769, in specialize
    kernel = self.compile(argtypes)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 785, in compile
    **self.targetoptions)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 57, in compile_kernel
    cres = compile_cuda(pyfunc, types.void, args, debug=debug, inline=inline)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 46, in compile_cuda
    locals={})
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 568, in compile_extra
    return pipeline.compile_extra(func)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 339, in compile_extra
    return self._compile_bytecode()
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 401, in _compile_bytecode
    return self._compile_core()
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 381, in _compile_core
    raise e
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 372, in _compile_core
    pm.run(self.state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_machinery.py", line 341, in run
    raise patched_exception
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_machinery.py", line 332, in run
    self._runPass(idx, pass_inst, state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_machinery.py", line 291, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_machinery.py", line 264, in check
    mangled = func(compiler_state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/typed_passes.py", line 98, in run_pass
    raise_errors=self._raise_errors)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/typed_passes.py", line 70, in type_inference_stage
    infer.propagate(raise_errors=raise_errors)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/typeinfer.py", line 986, in propagate
    raise errors[0]
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No conversion from int64 to none for '$12return_value.4', defined at None

File "test2.py", line 7:
def function_4(j, k):
    <source elided>
    l = j + k
    return l
    ^

[1] During: typing of assignment at /home/amu/Desktop/RL_framework/help_functions/test2.py (7)

File "test2.py", line 7:
def function_4(j, k):
    <source elided>
    l = j + k
    return l
    ^

Похоже, numba.cuda не поддерживает оператор return. Как же тогда использовать функции для вычисления значений? Оператор global, судя по всему, тоже не поддерживается:

import numba
from numba import cuda

@numba.cuda.jit
def function_4(j, k):
    global l
    l = j + k  # assigning to a global emits STORE_GLOBAL, which the error below reports as unsupported

function_4(1, 2)  # compilation is triggered by this first call (see traceback)
print(l)

Вывод:

Traceback (most recent call last):
  File "/home/amu/Desktop/RL_framework/help_functions/test.py", line 9, in <module>
    function_4(1, 2)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 758, in __call__
    kernel = self.specialize(*args)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 769, in specialize
    kernel = self.compile(argtypes)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 785, in compile
    **self.targetoptions)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 57, in compile_kernel
    cres = compile_cuda(pyfunc, types.void, args, debug=debug, inline=inline)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/cuda/compiler.py", line 46, in compile_cuda
    locals={})
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 568, in compile_extra
    return pipeline.compile_extra(func)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 339, in compile_extra
    return self._compile_bytecode()
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 401, in _compile_bytecode
    return self._compile_core()
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 381, in _compile_core
    raise e
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler.py", line 372, in _compile_core
    pm.run(self.state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_machinery.py", line 341, in run
    raise patched_exception
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_machinery.py", line 332, in run
    self._runPass(idx, pass_inst, state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_machinery.py", line 291, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/compiler_machinery.py", line 264, in check
    mangled = func(compiler_state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/untyped_passes.py", line 86, in run_pass
    func_ir = interp.interpret(bc)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/interpreter.py", line 116, in interpret
    flow.run()
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/byteflow.py", line 107, in run
    runner.dispatch(state)
  File "/home/amu/anaconda3/lib/python3.7/site-packages/numba/core/byteflow.py", line 269, in dispatch
    raise UnsupportedError(msg, loc=self.get_debug_loc(inst.lineno))
numba.core.errors.UnsupportedError: Failed in nopython mode pipeline (step: analyzing bytecode)
Use of unsupported opcode (STORE_GLOBAL) found

File "test.py", line 7:
def function_4(j, k):
    <source elided>
    global l
    l = j + k
    ^

1 Ответ

2 голоса
/ 13 июля 2020

Ваш код должен выглядеть примерно так:

import numpy as np
import numba 
from numba import cuda 

# The kernel does not return anything; the result is written into the
# first element of the output array `i` instead.
@cuda.jit 
def function_4(i, j, k):
     # i[0] receives the sum of the first elements of j and k.
     i[0] = j[0] + k[0]

# Inputs are passed as one-element int32 arrays rather than Python scalars.
j = np.array([1],  dtype=np.int32)
k = np.array([2],  dtype=np.int32)
# Output buffer with the same shape and dtype as j, initialised to zero.
i = np.zeros_like(j)

# Launch configuration [1, 1]: one block of one thread
# (Numba's kernel[griddim, blockdim] call syntax).
function_4[1,1](i, j, k)
print(i[0])

[Примечание: код написан на телефоне в зале вылета аэропорта и ни разу не проверялся, используйте на свой страх и риск]

По сути, всё нужно было передавать в виде массивов с явно заданными типами данных. Если вы собираетесь писать ядра, лучше начать с «родного» диалекта CUDA C++, который хорошо документирован, а затем вернуться к Numba, которая документирована гораздо хуже. Тогда всё станет очевидным.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...