Problem with a linear regression implementation - PullRequest
3 votes
/ 09 February 2020

I'm new to machine learning and have been trying to implement vectorized linear regression from scratch using numpy. I tried to sanity-check the implementation with y = x, but the loss keeps increasing and I can't figure out why. It would be great if someone could point out why this happens. Thanks in advance!

import numpy as np

class LinearRegressor(object):
    def __init__(self, num_features):
        self.num_features = num_features
        self.w = np.random.randn(num_features, 1).astype(np.float32)
        self.b = np.array(0.0).astype(np.float32)

    def forward(self, x):
        return np.dot(x, self.w) + self.b

    @staticmethod
    def loss(y_pred, y_true):
        l = np.average(np.power(y_pred - y_true, 2)) / 2
        return l

    def calculate_gradients(self, x, y_pred, y_true):
        self.dl_dw = np.dot(x.T, y_pred - y_true) / len(x)
        self.dl_db = np.mean(y_pred - y_true)

    def optimize(self, step_size):
        self.w -= step_size*self.dl_dw
        self.b -= step_size*self.dl_db

    def train(self, x, y, step_size=1.0):
        y_pred = self.forward(x)
        l = self.loss(y_pred=y_pred, y_true=y)
        self.calculate_gradients(x=x, y_pred=y_pred, y_true=y)
        self.optimize(step_size=step_size)
        return l

    def evaluate(self, x, y):
        return self.loss(self.forward(x), y)

check_reg = LinearRegressor(num_features=1)
x = np.arange(1000).reshape(-1, 1)
y = x
losses = []
for iteration in range(100):
    loss = check_reg.train(x=x,y=y, step_size=0.001)
    losses.append(loss)
    if iteration % 1 == 0:
        print("Iteration: {}".format(iteration))
        print(loss)

Output

Iteration: 0
612601.7859402705
Iteration: 1
67456013215.98818
Iteration: 2
7427849474110884.0
Iteration: 3
8.179099502901393e+20
Iteration: 4
9.006330707513148e+25
Iteration: 5
9.917228672922966e+30
Iteration: 6
1.0920254505132042e+36
Iteration: 7
1.2024725981084638e+41
Iteration: 8
1.324090295064888e+46
Iteration: 9
1.4580083421516024e+51
Iteration: 10
1.60547085025467e+56
Iteration: 11
1.7678478362285333e+61
Iteration: 12
1.946647415292399e+66
Iteration: 13
2.1435307416407376e+71
Iteration: 14
2.3603265498975516e+76
Iteration: 15
2.599049318486855e+81
Iteration: 16
nan
... (the loss stays nan for every remaining iteration through 99)

1 Answer

1 vote
/ 09 February 2020

There is nothing wrong with your implementation. Your step size is simply too large to converge: each update overshoots the minimum and bounces to an ever larger error. Change the step size like this:

loss = check_reg.train(x=x,y=y, step_size=0.000001)

and you get:

Iteration: 0
58305.102166924036
Iteration: 1
25952.192344178206
Iteration: 2
11551.585414406314
Iteration: 3
5141.729521746186
Iteration: 4
2288.6353484460747
Iteration: 5
1018.6952280352172
Iteration: 6
453.4320214875039
Iteration: 7
201.82728832044089
Iteration: 8
89.83519431606754
Iteration: 9
39.98665864625944
Iteration: 10
17.798416262435936
Iteration: 11
7.92229454258205
Iteration: 12
3.526272890501929
Iteration: 13
1.5696002444816197
Iteration: 14
0.6986516574778796
Iteration: 15
0.3109875219688626
Iteration: 16
0.13843156434074647
Iteration: 17
0.061616235257299326
Iteration: 18
0.027424318402401473
Iteration: 19
0.012205888201891543
Iteration: 20
0.005434012356344396
Iteration: 21
0.0024188644277583476
Iteration: 22
0.0010770380211645404
Iteration: 23
0.0004796730257022216
Iteration: 24
0.00021339295719587025
Iteration: 25
9.499628306355218e-05
Iteration: 26
4.244764386691682e-05
Iteration: 27
1.8965112443214162e-05
Iteration: 28
8.56069334821767e-06
Iteration: 29
3.848135476439999e-06
Iteration: 30
1.7367004907528985e-06
Iteration: 31
8.07976330965736e-07
Iteration: 32
4.0167090640020525e-07
Iteration: 33
2.253979336583221e-07
Iteration: 34
1.5365746125585947e-07
Iteration: 35
1.2480275459766612e-07
Iteration: 36
1.1147859663321005e-07
Iteration: 37
1.0288427880059631e-07
Iteration: 38
1.0036079530613815e-07
Iteration: 39
9.901975516098116e-08
Iteration: 40
9.901971962009025e-08
... (from here on the loss decreases only in the last few digits)
Iteration: 99
9.901762276149956e-08
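To make the divergence quantitative (this check is not part of the original answer): for a quadratic loss like this one, plain gradient descent is stable only when the step size is below 2/λ_max, where λ_max is the largest eigenvalue of the Hessian of the loss with respect to (w, b). A quick numpy sketch for the x = 0..999 data above:

```python
import numpy as np

x = np.arange(1000, dtype=np.float64)

# Hessian of (1/2)*mean((w*x + b - y)^2) with respect to (w, b)
H = np.array([[np.mean(x**2), np.mean(x)],
              [np.mean(x),    1.0]])

lam_max = np.linalg.eigvalsh(H).max()
limit = 2.0 / lam_max  # largest stable step size

print(lam_max, limit)
# the limit comes out around 6e-06, so step_size=0.001
# is far above it (divergence) while 1e-06 is below it (convergence)
```

This is why the loss roughly multiplies by a constant factor each iteration in the diverging run: the update matrix has an eigenvalue with magnitude greater than 1.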

Hope this helps!
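As a side note not in the original answer: a common alternative to shrinking the step size is to standardize the feature first. With zero-mean, unit-variance inputs the curvature of the loss is O(1), so step_size=1.0 becomes stable. A self-contained sketch of the same gradient updates on scaled data:

```python
import numpy as np

x = np.arange(1000, dtype=np.float64).reshape(-1, 1)
y = x
x_s = (x - x.mean()) / x.std()       # zero mean, unit variance

w, b = 0.0, 0.0
for _ in range(100):
    y_pred = x_s * w + b
    err = y_pred - y
    # same gradients as in the question, but step_size=1.0 is now stable
    w -= 1.0 * np.mean(err * x_s)
    b -= 1.0 * np.mean(err)

loss = np.mean(err**2) / 2
print(loss)  # effectively zero
```

The fitted parameters then live on the scaled axis (w recovers std(x), b recovers mean(x)); to predict on raw inputs you apply the same scaling at inference time.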
