Мой алгоритм линейной регрессии не работает - PullRequest
0 голосов
/ 18 октября 2019

Я закодировал алгоритм линейной регрессии, но вес и смещение не усваивают правильные значения

Фактические данные тренировки генерируются из y = x1 + x2

, поэтому w1, w2, b должны быть1, 1, 0 соответственно

Но ни один из них не получает правильные значения после обучения

Я не знаю, что не так с моим кодом

Заранее спасибо :))

Вот мой код ======================== КОД =========================

машинное обучение, линейная регрессия, pytorch

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d

import torch
import torch.nn as nn

#%%
N = 10
x1 = np.arange(N)
x2 = x1
y = x1 + x2 + 2

fig = plt.figure()
ax = fig.add_subplot(111,
                     projection = '3d')
plt.plot(x1, x2, y, 'bo')

x1 = x1.reshape(-1,1)
x2 = x2.reshape(-1,1)
x = np.hstack((x1, x2))

x_data = torch.tensor(x, dtype = torch.float)
y_data = torch.tensor(y, dtype = torch.float)

x_data.cuda()
y_data.cuda()
#%%
class LR(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        pred = self.linear(x)
        return pred

model = LR(2,1)

epochs = 100000
lr = 0.00001
check_freq = 1000


criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = lr)


x1_axis = np.linspace(-4, 4, 100)
x2_axis = x1_axis

losses = []
for i in range(epochs):
    pred = model.forward(x_data)
    loss = criterion(pred, y_data)
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


    if i%check_freq == 0:
        print("epoch: ", i, "loss:", loss.item())
        [w, b] = model.parameters()

        w1 = w[0][0].item()
        w2 = w[0][1].item()

        b = b[0].item()

        z = b + w1*x1_axis + w2*x2_axis
        plt.plot(x1_axis, x2_axis, z, 'r')
        plt.xlabel('x')
        plt.ylabel('y') 



for param in model.parameters():
    print(param)

plt.figure()
plt.plot(range(epochs), losses)
...