Ошибка Pytorch "Повторная попытка прокрутки графика назад" - PullRequest
0 голосов
/ 27 мая 2020

У меня проблема с pytorch. Когда я запускаю свой код, на шаге loss.backward () я получаю сообщение об ошибке, в котором говорится, что torch пытается вернуться назад во второй раз. Когда я запускаю

import torch

class module(torch.nn.Module):
    def __init__(self):
        super().__init__()

        bond_dim = 4
    tensor_num_feats = 8
    self.alpha = torch.rand(bond_dim)

    tensor_core = torch.rand((bond_dim, tensor_num_feats, bond_dim), requires_grad=True)
    tensor_core = torch.nn.Parameter(tensor_core)
    self.register_parameter('tensor',tensor_core)
    eye = torch.eye(bond_dim,bond_dim, requires_grad = False, dtype=torch.float)
    batch_core = torch.zeros(bond_dim, tensor_num_feats + 1, bond_dim, dtype=torch.float)
    batch_core[:, 0, :] = eye
    batch_core[:, 1:, :] = tensor_core[:, :, :]
    self.tensor_core = batch_core


    def forward(self,inputs):
        return torch.einsum('i,j,ijk->k',self.alpha, inputs, self.tensor_core)

m = module()
x = torch.rand(9)
y = torch.rand(4)
num_epoch = 4

loss_fun = torch.nn.MSELoss()
optimizer = torch.optim.Adam(m.parameters(), lr=1e-3, weight_decay=0)

for num in range(num_epoch):
    out = m(x)
    loss = loss_fun(out,y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

, я получаю сообщение

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Но не получаю это сообщение, если я заменяю

tensor_core = torch.nn.Parameter(tensor_core)
self.register_parameter('tensor',tensor_core)
eye = torch.eye(bond_dim,bond_dim, requires_grad = False, dtype=torch.float)
batch_core = torch.zeros(bond_dim, tensor_num_feats + 1, bond_dim, dtype=torch.float)
batch_core[:, 0, :] = eye
batch_core[:, 1:, :] = tensor_core[:, :, :]
self.tensor_core = batch_core

на

    self.real_core = torch.nn.Parameter(tensor_core)
    eye = torch.eye(bond_dim,bond_dim, requires_grad = False, dtype=torch.float)
    eye = eye.unsqueeze(1)
    batch_core = torch.cat((eye,self.real_core),1)
    self.tensor_core = batch_core

Что здесь происходит?

...