PyTorch: повторная попытка пролистать график, но буферы уже освобождены - PullRequest
3 голосов
/ 29 февраля 2020

Моя модель:

class BaselineModel(nn.Module):
    def __init__(self, feature_dim=5, hidden_size=5, num_layers=2, batch_size=32):
        super(BaselineModel, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size=feature_dim,
                            hidden_size=hidden_size, num_layers=num_layers)

    def forward(self, x, hidden):
        lstm_out, hidden = self.lstm(x, hidden)
        return lstm_out, hidden

    def init_hidden(self, batch_size):
        hidden = Variable(next(self.parameters()).data.new(
            self.num_layers, batch_size, self.hidden_size))
        cell = Variable(next(self.parameters()).data.new(
            self.num_layers, batch_size, self.hidden_size))
        return (hidden, cell)

Моя тренировка l oop выглядит следующим образом:


    for epoch in range(250):
        hidden = model.init_hidden(13)
        # hidden = (torch.zeros(2, 13, 5),
        #           torch.zeros(2, 13, 5))
        # model.hidden = hidden
        for i, data in enumerate(train_loader):
            inputs = data[0]
            outputs = data[1]

            print('inputs',  inputs.size())
            # print('outputs', outputs.size())

            optimizer.zero_grad()
            model.zero_grad()

            # print('inputs', inputs)
            pred, hidden = model(inputs, hidden)

            loss = loss_fn(pred[0], outputs)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

Кажется, я прошел первую эпоху, затем вижу эту ошибку: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Я видел сообщения об этом, но в моем случае я заново инициализирую свой hidden для каждой партии.

1 Ответ

1 голос
/ 01 марта 2020

model.init_hidden(13) должно быть в партии l oop, а не в эпоху l oop

...