Как заставить функцию Cython принимать ввод с плавающей запятой или двойной массив? - PullRequest
0 голосов
/ 04 мая 2018

Предположим, у меня есть следующая (MCVE ...) функция Cython

cimport cython

from scipy.linalg.cython_blas cimport dnrm2


cpdef double func(int n, double[:] x):
   cdef int inc = 1
   return dnrm2(&n, &x[0], &inc)

Тогда я не могу вызвать его для np.float32 массива x.

Как я могу заставить func принять double[:] или float[:] и позвонить dnrm2 или snrm2 в качестве альтернативы? Единственное решение, которое у меня есть в настоящее время, - это две функции, которые создают огромное количество дублирующегося кода.

1 Ответ

0 голосов
/ 04 мая 2018

Вы можете использовать плавленый тип. Обратите внимание, что нижеприведенное ниже не компилируется в моей системе, потому что ddot и sdot, по-видимому, требуют 5 параметров:

# cython: infer_types=True
cimport cython

from scipy.linalg.cython_blas cimport ddot, sdot

ctypedef fused anyfloat:
   double
   float

cpdef anyfloat func(int n, anyfloat[:] x):
   cdef int inc = 1
   if anyfloat is double:
      return ddot(&n, &x[0], &inc)
   else:
      return sdot(&n, &x[0], &inc)
...