Множественная линейная регрессия с градиентным спуском - PullRequest
1 голос
/ 23 мая 2019

Halo,

Я новичок в машинном обучении и Python, и я хочу предсказать Продажи Kaggle House в наборе данных округа Кинг с моим градиентным спуском .

Я разделяю 70% (15k строк) тестирования и 30% (6k строк) тестирования, и я выбираю 5 функций из 19, но есть проблема с производительностью, алгоритм занял так много времени (более 11часов), 100% памяти и не удалось выполнить.

Это мой класс по градиентному спуску:

class GradientDescent:

    X_train = []
    Y_train = []
    X_test  = []
    Y_test  = []
    lr = 0
    max_iter = 0
    theta = 0

    def __init__(self, X_train,Y_train,X_test,Y_test, lr=0.01, max_iter=100):
        self.X_train = X_train
        self.Y_train = Y_train
        self.X_test  = X_test
        self.Y_test  = Y_test
        self.lr = lr
        self.max_iter = max_iter
        self.theta = np.random.randn(X_train.shape[1], 1)
        print(self.theta)

    def costFunction(self,theta,X,y):
        "1/2m * E(h0-y)**2"
        m = len(y)
        y_pred = X.dot(theta)
        cost = (1/2*m) * np.sum(np.square(y_pred-y))

        return cost


    def estimate(self):
        m = len(self.Y_train)

        mse_hist = np.zeros(self.max_iter)

        #theta_hist = np.zeros(max_iter)
        i = 0
        while i < self.max_iter or mse_hist[i] > 0.01:
            y_pred = np.dot(self.X_train,self.theta)

            error = y_pred-self.Y_train
            self.theta = self.theta - (1/m)*self.lr*(self.X_train.T.dot((error)))
            mse_hist[i] = self.costFunction(self.theta,self.X_train, self.Y_train)

            #print(mse_hist[i])
            i+=1            
        return (self.theta, mse_hist)


    def test(self):
        res = pd.DataFrame()
        for i,row in self.X_test.iterrows():
            price_pred = np.dot(row.values,self.theta)
            res = row
            res['price_actual'] = self.Y_test[i]
            res['price_predict'] = price_pred

        res['r2_score'] = r2_score(res['price_actual'].values, res['price_predict'])
        res.to_csv('output.csv')

Есть ли какой-нибудь совет, чтобы сделать его лучше?

1 Ответ

1 голос
/ 23 мая 2019

В целом код выглядит нормально, хотя я его не проверял.Единственная ошибка, которую я смог найти, заключалась в том, что вы не можете увеличивать i в цикле while, поэтому цикл никогда не завершается.

...