LSTM - Прогнозирование одинаковых значений констант после нескольких строк - PullRequest
0 голосов
/ 02 августа 2020

Я создаю модуль, используя nn.LSTM для предиката того же самого, но он предсказывает те же значения после 5-10 прогнозов:

    def __init__(self, input_size=len(train_cols), hidden_size=40, output_size=2, num_layer=10):
        super(NET, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layer)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.out(out[:, -1, :])
        return out
net = NET(output_size=1)
optimizer = torch.optim.Adam(net.parameters(), lr=0.08, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
loss_func = torch.nn.MSELoss()
test_x = test_x.reshape(-1, 1, len(train_cols))
test_x = torch.from_numpy(test_x)
test_xv = Variable(test_x)
for epoch in range(100):
    var_x = Variable(train_x).type(torch.FloatTensor)
    var_y = Variable(train_y).type(torch.FloatTensor)
    out = net(var_x)
    loss = loss_func(out.squeeze(), var_y.squeeze())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 5 == 0:
        print('Epoch: {}, Loss: {:.5f}'.format(epoch + 1, loss.data.numpy()))
        pred_test = net(test_xv)
        pred_test = pred_test.view(-1).data.numpy()
        viz.text(str(np.dstack((pred_test,test_y))), win='compare',
                 opts=dict(title='pred'))

Прогноз:

[ [[0,90698224 0,97657084] [1,3066663 0,98598385] [1,4000531 1,003238] [1,4139748 1,018212] [1,4166933 1,0115894] [1,4169703 0,99727374] [1,4168639 0,9989065] [1,4166994 0,9569663 0,9989065] [1,4166994 0,9569661] [1,416420822 0,9569661] [1,416420822 0,9569661] [1,416420822 0,9569661] [1,416420822] 0,9 1,4161557 1,0281057] [1,4160964 1,0196408] [1,4160475 1,0191873] [1,4160074 1,0168635] [1,4159743 1,0126513] [1,4159468 1,004374] [1,4159242 1,0187122] [1,4159056 1,0370985] [1,4158906153 1,067834 1,0370985] [1,4158908463 1,067899] 1,0370985] [1,4158906153 1,067934 1,07 ] [1,4158456 0,97619045] [1,4158409 0,97762746] [1,4158368 0,9567483] [1,4158336 0,93333334] [1,4158307 0,9560899]]]

Кто мне скажет почему?

...