Ошибка несоответствия формы с scipy.optimize.minimize для логистической регрессии - PullRequest
0 голосов
/ 05 октября 2018

Я прошёл курс ML Andrew Ng по ML и пытаюсь реализовать программы на python.Для второго упражнения по логистической регрессии я пытаюсь использовать scipy.optimize.minimize для оптимизации функции затрат.Мой код выглядит следующим образом.

import os
import numpy as np
from scipy.special import expit
from scipy import optimize

datafile1 = os.path.join('data','ex2data1.txt')
data1 = np.loadtxt(datafile1, delimiter=',')
exam_scores, results = data1[:, :2], data1[:, 2]

m, n = exam_scores.shape

exam_scores = np.concatenate([np.ones([m, 1]), exam_scores], axis=1)

def cost_function(x, y, theta):
    m = len(y)
    hypothesis = expit(np.dot(x, theta))
    term1 = -np.dot(y.T, np.log(hypothesis)) / m
    term2 = -np.dot((1 - y).T, np.log(1 - hypothesis)) / m
    cost = term1 + term2
    return cost

def gradient(x, y, theta):
    m = len(y)
    hypothesis = expit(np.dot(x, theta))
    return np.dot(hypothesis - y, x) / m

def minimize_cost(x, y, theta):
    output = optimize.minimize(cost_function, theta, args=(x, y),
                               jac=gradient, options={'maxiter':400})
    return output.fun, output.x

theta = np.zeros(n + 1)
theta, cost = minimize_cost(exam_scores, results, theta)

Это дает мне

<ipython-input-42-e2ba65cce1d8> in gradient(x, y, theta)
       9 def gradient(x, y, theta):
      10     m = len(y)
 ---> 11     hypothesis = expit(np.dot(x, theta))
      12     return np.dot(hypothesis - y, x) / m

      ValueError: shapes (3,) and (100,) not aligned: 3 (dim 0) != 100 (dim 0).

Однако форма theta и выходные данные функции gradient одинаковы, то есть theta.shape == gradient(exam_scores, results, theta).shape дает мне True.

Я не понимаю, почему функция градиента вызывает ValueError при вызове из minimize, поскольку сама по себе она дает ожидаемый результат.

Любойуказатели приветствуются.

PS Вот часть данных.

exam_scores[:5, :]
array([[34.62365962, 78.02469282],
       [30.28671077, 43.89499752],
       [35.84740877, 72.90219803],
       [60.18259939, 86.3085521 ],
       [79.03273605, 75.34437644]])

results.reshape(m, 1)[:5, :]
array([[0.],
       [0.],
       [0.],
       [1.],
       [1.]])

Редактировать: Добавлена ​​часть данных.

...