Привет всем, я пытаюсь создать модель с использованием класса PyTorch RNN
и обучить эту модель с использованием мини-пакетов. Мой набор данных - простой таймер ie (один вход один выход). Вот как выглядит моя модель:
class RNN_pytorch(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN_pytorch, self).__init__()
self.hidden_size = hidden_size
self.input_size = input_size
self.rnn = nn.RNN(input_size, hidden_size, num_layers=1)
self.linear = nn.Linear(hidden_size, output_size)
def forward(self, x, hidden):
batch_size = x.size(1)
# print(batch_size)
hidden = self.init_hidden(batch_size)
out, hidden = self.rnn(x, hidden)
# out = out.view(out.size(1), out.size(2))
print("Input linear : ", out.size())
out = self.linear(out)
return out, hidden
def init_hidden(self, batch_size):
hidden = torch.zeros(1, batch_size, self.hidden_size)
# print(hidden.size())
return hidden
Затем я обрабатываю свой набор данных и делю его следующим образом:
batch_numbers = 13
batch_size = int(len(train_signal[:-1])/batch_numbers)
print("Train sample total size =", len(train_signal[:-1]))
print("Number of batches = ", batch_numbers)
print("Size of batches = {} (train_size / batch_numbers)".format(batch_size))
train_signal_batched = train_signal[:-1].reshape(batch_numbers, batch_size, 1)
train_label_batched = train_signal[1:].reshape(batch_numbers, batch_size, 1)
print("X_train shape =", train_signal_batched.shape)
print("Y_train shape =", train_label_batched.shape)
Возвращается:
Train sample total size = 829439
Number of batches = 13
Size of batches = 63803 (train_size / batch_numbers)
X_train shape = (13, 63803, 1)
Y_train shape = (13, 63803, 1)
Пока так хорошо, но потом я пытаюсь обучить свою модель:
rnn_mod = RNN_pytorch(1, 16, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.RMSprop(rnn_mod.parameters(), lr=0.01)
n_epochs = 3
hidden = rnn_mod.init_hidden(batch_size)
for epoch in range(1, n_epochs):
for i, batch in enumerate(train_signal_batched):
optimizer.zero_grad()
x = torch.Tensor([batch]).float()
print("Input : ",x.size())
out, hidden = rnn_mod.forward(x, hidden)
print("Output : ",out.size())
label = torch.Tensor([train_label_batched[i]]).float()
print("Label : ", label.size())
loss = criterion(output, label)
print("Loss : ", loss)
loss.backward(retain_graph=True)
optimizer.step()
print("*", end="")
# if epoch % 100 == 0:
print("Step {} --- Loss {}".format(epoch, loss))
, что приводит к ошибке:
Input : torch.Size([1, 63803, 1])
Input linear : torch.Size([1, 63803, 16])
Output : torch.Size([1, 63803, 1])
Label : torch.Size([1, 63803, 1])
Loss : tensor(0.0051)
/home/kostia/.virtualenvs/machine-learning/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 63803, 1])) that is different to the input size (torch.Size([1, 1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
return F.mse_loss(input, target, reduction=self.reduction)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-217-d019358438ff> in <module>
17 loss = criterion(output, label)
18 print("Loss : ", loss)
---> 19 loss.backward(retain_graph=True)
20 optimizer.step()
21 print("*", end="")
~/.virtualenvs/machine-learning/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
116 products. Defaults to ``False``.
117 """
--> 118 torch.autograd.backward(self, gradient, retain_graph, create_graph)
119
120 def register_hook(self, hook):
~/.virtualenvs/machine-learning/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
91 Variable._execution_engine.run_backward(
92 tensors, grad_tensors, retain_graph, create_graph,
---> 93 allow_unreachable=True) # allow_unreachable flag
94
95
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Может кто-нибудь сказать мне, в чем здесь проблема, потому что я, честно говоря, не ' Понятия не имею?
Заранее спасибо