Почему cython принимает float32, если я указал float64 dtype в функции? - PullRequest
0 голосов
/ 13 апреля 2020

Я определил эту функцию:

import numpy as np
cimport numpy as np

ctypedef np.float64_t dtype_t

to_n = lambda n: np.arange(1, n+1, dtype = 'int')
indexr = lambda i: i-1

def forward_substitution(np.ndarray[dtype_t, ndim = 2] L, np.ndarray[dtype_t, ndim = 1] b):
    cdef int n = len(b)
    cdef np.ndarray y = np.zeros(n)
    y[indexr(1)] = b[indexr(1)]/L[indexr(1), indexr(1)]
    cdef int i
    cdef int j
    cdef double suma
    for i in np.arange(2, n+1):
        suma = 0
        for j in to_n(i-1):
            suma = suma + L[indexr(i), indexr(j)]*y[indexr(j)]
        y[indexr(i)] = (b[indexr(i)] - suma)/L[indexr(i), indexr(i)]

    return(y)

Я сохранил ее в файле с именем forward_substitution.pyx, а затем создал этот файл:

import numpy as np
from setuptools import setup
from Cython.Build import cythonize

setup(
    ext_modules=cythonize("forward_substitution.pyx")
)

с именем setup_forward_substitution.py.

Я запустил на терминале

python3 setup_forward_substitution.py build_ext --inplace

и изменил имя .so файла на forward_substitution.so.

Затем в Spyder я установил рабочий каталог, в котором были эти файлы нашел и попробовал

from forward_subtitution import forward_substitution

Я пытался с этими матрицами:

L = np.array([[3, 0, 0, 0], 
              [2, -3, 0, 0], 
              [1, 0, 5, 0], 
              [0, 2, 4, -3]]).astype('float64')
b = np.array([6, 7, -8, -3]).astype('float64')

forward_substitution(L, b)

и это выдает эту ошибку:

Traceback (most recent call last):

  File "<ipython-input-196-59032bf074d0>", line 1, in <module>
    forward_substitution(L, b)

  File "forward_substitution.pyx", line 7, in forward_substitution.forward_substitution
    indexr = lambda i: i-1

ValueError: Buffer dtype mismatch, expected 'float' but got 'double'

Но если я изменю .astype (' float64 ') в .astype (' float32 ') не выдает ошибок, и результат в порядке.

Почему это происходит, если функция была определена с dtype np.float64_t?

...