Я пытаюсь построить модель линейной регрессии и найти оптимальные значения с помощью оптимизатора fmin_cg
.
У меня есть две функции для этой работы.Первый linear_reg_cost
, который является функцией стоимости, и второй linear_reg_grad
, который является градиентом функции стоимости.Обе функции имеют одинаковый аргумент.
def hypothesis(x,theta):
return np.dot(x,theta)
Функция стоимости:
def linear_reg_cost(x_flatten, y, theta_flatten, lambda_, num_of_features,num_of_samples):
x = x_flatten.reshape(num_of_samples, num_of_features)
theta = theta_flatten.reshape(n,1)
loss = hypothesis(x,theta)-y
regularizer = lambda_*np.sum(theta[1:,:]**2)/(2*m)
j = np.sum(loss ** 2)/(2*m)
return j
Градиентная функция:
def linear_reg_grad(x_flatten, y, theta_flatten, lambda_, num_of_features,num_of_samples):
x = x_flatten.reshape(num_of_samples, num_of_features)
m,n = x.shape
theta = theta_flatten.reshape(n,1)
new_theta = np.zeros(shape=(theta.shape))
loss = hypothesis(x,theta)-y
gradient = np.dot(x.T,loss)
new_theta[0:,:] = gradient/m
new_theta[1:,:] = gradient[1:,:]/m + lambda_*(theta[1:,]/m)
return new_theta
и fmin_cg
:
theta = np.ones(n)
from scipy.optimize import fmin_cg
new_theta = fmin_cg(f=linear_reg_cost, x0=theta, fprime=linear_reg_grad,args=(x.flatten(), y, lambda_, m,n))
Примечание: я сглаживаю x
в качестве входных данных и извлекаю в функции стоимости и градиента в качестве матрицы.
ошибка вывода:
<ipython-input-98-b29c1b8f6e58> in linear_reg_grad(x_flatten, y, theta_flatten, lambda_, num_of_features, num_of_samples)
1 def linear_reg_grad(x_flatten, y, theta_flatten, lambda_,num_of_features, num_of_samples):
----> 2 x = x_flatten.reshape(num_of_samples, num_of_features)
3 m,n = x.shape
4 theta = theta_flatten.reshape(n,1)
5 new_theta = np.zeros(shape=(theta.shape))
ValueError: cannot reshape array of size 2 into shape (2,12)
Примечание: x.shape = (12,2)
, y.shape = (12,1)
, theta.shape = (2,)
.Так что num_of_features =2
и num_of_samples=12
.Но ошибка показывает, что мой ввод x
анализируется вместо theta
.Почему это происходит, даже когда я явно назначил args
в fmin_cg
?И как мне решить эту проблему?
Спасибо за любой совет