Потеря Pytorch инфан - PullRequest
       7

Потеря Pytorch инфан

0 голосов
/ 26 июня 2018

Я пытаюсь сделать простую линейную регрессию с 1 функцией.Это простая проблема «предсказать зарплату с учетом многолетнего опыта».NN обучается на многолетнем опыте (X) и зарплате (Y).По какой-то причине потеря взрывается и в конечном итоге возвращает inf или nan

У меня есть такой код:

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

dataset = pd.read_csv('./salaries.csv')

x_temp = dataset.iloc[:, :-1].values
y_temp = dataset.iloc[:, 1:].values

X_train = torch.FloatTensor(x_temp)
Y_train = torch.FloatTensor(y_temp)

class Model(torch.nn.Module): 
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1,1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

model = Model()

loss_func = torch.nn.MSELoss(size_average=False)
optim = torch.optim.SGD(model.parameters(), lr=0.01)

#training 
for epoch in range(200):
    #calculate y_pred
    y_pred = model(X_train)

    #calculate loss
    loss = loss_func(y_pred, Y_train)
    print(epoch, "{:.2f}".format(loss.data))

    #backward pass + update weights
    optim.zero_grad()
    loss.backward()
    optim.step()


test_exp = torch.FloatTensor([[8.0]])
print("8 years experience --> ", model(test_exp).data[0][0].item())

Как я уже говорил, после начала обучения потеряочень большой и в конечном итоге показывает inf после 10-й эпохи.

Я подозреваю, что это может иметь какое-то отношение к тому, как я загружаю данные?Это то, что находится в salaries.csv файле:

Years Salary
1.1 39343
1.3 46205
1.5 37731
2   43525
2.2 39891
2.9 56642
3   60150
3.2 54445
3.2 64445
3.7 57189
3.9 63218
4   55794
4   56957
4.1 57081
4.5 61111
4.9 67938
5.1 66029
5.3 83088

Спасибо за вашу помощь

Ответы [ 2 ]

0 голосов
/ 10 июля 2019

Вот пример, как это все происходит.Вы можете попробовать запустить эту программу, которая в основном представляет сеть уровня r-deep.

import torch
import math
import matplotlib.pyplot as plt
def stat(t, p=True):
    m = t.mean()
    s = t.std()
    if p==True:
        print(f"MEAN: {m}, STD: {s}")
    return(m,s)

_m = []
_s = []

c = 100
r = 50# repeat steps
x = torch.randn(c)
m = torch.randn(c,c)#/math.sqrt(n)
stat(x)

for _ in range (0,r):
    x = m@x    
    _1, _2 = stat(x, False)
    _m.append(_1)
    _s.append(_2)


stat(x)

plt.plot(_m)
plt.plot(_s)
plt.legend(["mean","std"])
plt.show()

enter image description here

0 голосов
/ 29 июня 2018

Как только потеря становится равной inf после определенного прохода, ваша модель будет повреждена после обратного распространения.Вероятно, это происходит потому, что значения в столбце «Зарплата» слишком велики.попробуйте нормализовать оклады.

В качестве альтернативы, вы можете попытаться инициализировать параметры вручную (вместо того, чтобы инициализировать их случайным образом), позволяя смещенному члену быть средним из окладов, а наклон линии равен 0(например).Таким образом, исходная модель будет достаточно близка к оптимальному решению, чтобы потери не увеличивались.

...