PyTorch - Перезаписанные переменные остаются в графе? - PullRequest
0 голосов
/ 11 октября 2018

Мне интересно, хранятся ли в вычислительном графе PyTorch тензоры PyTorch, в которых переменные Python перезаписываются?


Итак, вот небольшой пример, где у меня есть модель RNN, в которой скрытые состояния (и некоторые другие переменные) сбрасываются после каждой итерации, backward() вызывается позже.

Пример:

for i in range(5):
   output = rnn_model(inputs[i])
   loss += criterion(output, target[i])
   ## hidden states are overwritten with a zero vector
   rnn_model.reset_hidden_states() 
loss.backward()

Итак, мой вопрос:

  • существует проблема перезаписи скрытых состояний перед вызовом backward()?

  • Или вычислительный граф хранит в памяти необходимую информацию о скрытых состояниях предыдущих итераций для вычисления градиентов?

  • Редактировать: было бы здорово иметь официальное заявление для этого.например, указав, что все переменные, относящиеся к CG, сохранены - независимо от того, есть ли еще другие ссылки на python для этих переменных.Я предполагаю, что в самом графе есть ссылка, не позволяющая сборщику мусора удалить его.Но я хотел бы знать, так ли это на самом деле.

Заранее спасибо!

1 Ответ

0 голосов
/ 12 октября 2018

Я думаю, что все в порядке, прежде чем вернуться назад.График сохраняет необходимую информацию.

class A (torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.f1 = torch.nn.Linear(10,1)
     def forward(self, x):
         self.x = x 
         return torch.nn.functional.sigmoid (self.f1(self.x))
     def reset_x (self):
        self.x = torch.zeros(self.x.shape) 
net = A()
net.zero_grad()
X = torch.rand(10,10) 
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
loss.backward()
params = list(net.parameters())
for i in params: 
    print(i.grad)
net.zero_grad() 

loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
net.reset_x()
print (net.x is X)
del X
loss.backward()     
params = list(net.parameters())
for i in params:
    print(i.grad)

В приведенном выше коде я печатаю грады с / без сброса ввода x.Градиент зависит от x, и его сброс не имеет значения.Поэтому я думаю, что граф сохраняет информацию для обратной операции.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...