Pytorch - lstm выдает ошибку retain_graph - как мне обойти это? - PullRequest
0 голосов
/ 03 марта 2020

Я обучаю простую модель LSTM, однако pytorch выдает мне ошибку, говоря, что мне нужно установить retain_graph = True. Однако для обучения модели требуется больше времени, и я не думаю, что мне нужно это делать.

class SequenceModel(nn.Module):

def __init__(self):
    super().__init__()
    self.lstm = nn.LSTM(input_size = 3, hidden_size = 3, bidirectional=False)
    self.hidden = (torch.randn(1, 1, 3).double(), torch.randn(1, 1, 3).double())

def forward(self,x):
    lstm_out, self.hidden = self.lstm(x.view(-1, 1, 3),self.hidden)
    return lstm_out

def loss(self,logits,labels):
    return F.cross_entropy(logits, labels)

model = SequenceModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model = model.double()

model.train()
epochs = 1000
for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()


    logits = model(inputs)
    logits = logits.reshape(-1,3)
    loss = model.loss(logits,outputs.long())

    loss.backward() 
    optimizer.step()

Я получаю ошибку:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Однако я не хочу устанавливать для retain_graph значение True.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...