Как передать указатель массива на функцию Numba? - PullRequest
1 голос
/ 29 апреля 2020

Я хотел бы создать скомпилированную Numba функцию, которая принимает указатель или адрес памяти массива в качестве аргумента и выполняет вычисления на нем, например, изменяет базовые данные.

Чисто python версия для иллюстрации выглядит следующим образом:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array


def modify_data(addr):
    """ a function taking the memory address of an array to modify it """
    ptr = ctypes.c_void_p(addr)
    data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
    data += 2

addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])

Как видно из примера, массив arr был изменен без явной передачи его функции. В моем случае форма и тип d массива известны и всегда будут оставаться неизменными, что должно упростить интерфейс.

1. Попытка: наивное джитинг

Теперь я попытался скомпилировать функцию modify_data, но не получилось. Моей первой попыткой было использование

shape = arr.shape
dtype = arr.dtype

@nb.njit
def modify_data_nb(ptr):
    data = nb.carray(ptr, shape, dtype=dtype)
    data += 2


ptr = ctypes.c_void_p(addr)
modify_data_nb(ptr)   # <<< error

Это не удалось с cannot determine Numba type of <class 'ctypes.c_void_p'>, т. Е. Он не знает, как интерпретировать указатель.

2. Попытка: явные типы

Я пытался поставить явные типы

arr_ptr_type = nb.types.CPointer(nb.float64)
shape = arr.shape

@nb.njit(nb.types.void(arr_ptr_type))
def modify_data_nb(ptr):
    """ a function taking the memory address of an array to modify it """
    data = nb.carray(ptr, shape)
    data += 2

, но это не помогло. Он не выдает никаких ошибок, но я не знаю, как вызвать функцию modify_data_nb. Я попробовал следующие опции

modify_data_nb(arr.ctypes.data)
# TypeError: No matching definition for argument type(s) int64

ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject

ptr = ctypes.c_void_p(arr.ctypes.data)
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject

Есть ли способ получить правильный формат указателя из arr, чтобы я мог передать его скомпилированной в Numba функции modify_data_nb? Альтернативно, есть ли другой способ передачи ячейки памяти в функцию.

3. Попытка: используя scipy.LowLevelCallable

Я добился определенного прогресса, используя scipy.LowLevelCallable и его маги c:

arr = np.arange(3).astype(np.double)
print(arr)
# >>> array([0., 1., 2.])

# create the function taking a pointer
shape = arr.shape
dtype = arr.dtype

@nb.cfunc(nb.types.void(nb.types.CPointer(nb.types.double)))
def modify_data(ptr):
    data = nb.carray(ptr, shape, dtype=dtype)
    data += 2

modify_data_llc = LowLevelCallable(modify_data.ctypes).function    

# create pointer to array
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

# call the function only with the pointer
modify_data_llc(ptr)

# check whether array got modified
print(arr)
# >>> array([2., 3., 4.])

Теперь я могу вызвать функцию для доступа к массиву, но это функция больше не является функцией Numba. В частности, его нельзя использовать в других функциях Numba.

1 Ответ

1 голос
/ 01 мая 2020

Благодаря замечательному @stuartarchibald у меня теперь есть рабочее решение:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array
print(arr)

@nb.extending.intrinsic
def address_as_void_pointer(typingctx, src):
    """ returns a void pointer from a given memory address """
    from numba.core import types, cgutils
    sig = types.voidptr(src)

    def codegen(cgctx, builder, sig, args):
        return builder.inttoptr(args[0], cgutils.voidptr_t)
    return sig, codegen

addr = arr.ctypes.data

@nb.njit
def modify_data():
    """ a function taking the memory address of an array to modify it """
    data = nb.carray(address_as_void_pointer(addr), arr.shape, dtype=arr.dtype)
    data += 2

modify_data()
print(arr)

Ключ - новая функция address_as_void_pointer, которая превращает адрес памяти (заданный как int) в указатель, который может использоваться numba.carray.

...