Я обучаю простую модель LSTM, однако pytorch выдает мне ошибку, говоря, что мне нужно установить retain_graph = True. Однако для обучения модели требуется больше времени, и я не думаю, что мне нужно это делать.
class SequenceModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(input_size = 3, hidden_size = 3, bidirectional=False)
self.hidden = (torch.randn(1, 1, 3).double(), torch.randn(1, 1, 3).double())
def forward(self,x):
lstm_out, self.hidden = self.lstm(x.view(-1, 1, 3),self.hidden)
return lstm_out
def loss(self,logits,labels):
return F.cross_entropy(logits, labels)
model = SequenceModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model = model.double()
model.train()
epochs = 1000
for epoch in tqdm(range(epochs)):
optimizer.zero_grad()
logits = model(inputs)
logits = logits.reshape(-1,3)
loss = model.loss(logits,outputs.long())
loss.backward()
optimizer.step()
Я получаю ошибку:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Однако я не хочу устанавливать для retain_graph значение True.