Я пытаюсь закодировать метод наименьших квадратов в Cython для целей обучения.У меня работает эта базовая версия:
import cython
import numpy as np
from scipy.linalg import inv
cimport numpy as np
def ols_c(np.ndarray x, np.ndarray y):
cdef int nrowx = x.shape[0]
cdef int ncolx = x.shape[1]
cdef np.ndarray beta = np.zeros([ncolx,1], dtype=float)
cdef np.ndarray a1 = np.zeros([ncolx, ncolx], dtype=float)
cdef np.ndarray a2 = np.zeros([ncolx, nrowx], dtype=float)
a1 = inv(np.dot(x.T,x))
a2 = np.dot(a1,x.T)
beta = np.dot(a2,y)
return(beta)
, которая немного медленнее, чем эта версия Numpy:
import numpy as np
from scipy.linalg import inv
def ols(x,y):
a1 = inv(np.dot(x.T,x))
a2 = np.dot(a1,x.T)
beta = np.dot(a2,y)
return(beta)
Я думаю, это вероятно из-за неэффективной индексации массива.Следуя инструкциям в Интернете, я изменил базовую версию Cython следующим образом:
import cython
import numpy as np
from scipy.linalg import inv
cimport numpy as np
DTYPE = np.float
ctypedef np.float_t DTYPE_t
def ols_c(np.ndarray[DTYPE_t, ndim=2] x, np.ndarray[DTYPE_t, ndim=1] y):
cdef int nrowx = x.shape[0]
cdef int ncolx = x.shape[1]
cdef np.ndarray[DTYPE_t, ndim=1] beta = np.zeros([ncolx,1], dtype=float)
cdef np.ndarray[DTYPE_t, ndim=2] a1 = np.zeros([ncolx, ncolx], dtype=float)
cdef np.ndarray[DTYPE_t, ndim=2] a2 = np.zeros([ncolx, nrowx], dtype=float)
a1 = inv(np.dot(x.T,x))
a2 = np.dot(a1,x.T)
beta = np.dot(a2,y)
return(beta)
Но теперь она не работает, я получаю следующее сообщение об ошибке:
ValueError: Buffer has wrong number of dimensions (expected 1, got 2)
Что вызываетэта ошибка?У меня также есть некоторые другие вопросы:
Что на самом деле делают эти 2 строки?
DTYPE = np.float
ctypedef np.float_t DTYPE_t
Кроме того, если я правильно понимаю, введите этот cdef np.ndarray [DTYPE_t, ndim = 2] x= np.zeros ([ncol, nrow], dtype = float) создает двумерный массив x с числом столбцов, равным ncol, и строкой, равной nrow, которые содержат числа с плавающей точкой.Но что на самом деле делает [DTYPE_t, ndim = 2]?Я не нашел никакой документации по этому вопросу.
Заранее спасибо за ваши ответы!
РЕДАКТИРОВАТЬ: похоже, если я заменю DTYPE_t на double и прокомментирую эти две строки:
DTYPE = np.float
ctypedef np.float_t DTYPE_t
HОднако выполнение все еще идет медленно.Что я могу сделать, чтобы ускорить процесс?