Я думаю, что все в порядке, прежде чем вернуться назад.График сохраняет необходимую информацию.
class A (torch.nn.Module):
def __init__(self):
super().__init__()
self.f1 = torch.nn.Linear(10,1)
def forward(self, x):
self.x = x
return torch.nn.functional.sigmoid (self.f1(self.x))
def reset_x (self):
self.x = torch.zeros(self.x.shape)
net = A()
net.zero_grad()
X = torch.rand(10,10)
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
loss.backward()
params = list(net.parameters())
for i in params:
print(i.grad)
net.zero_grad()
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
net.reset_x()
print (net.x is X)
del X
loss.backward()
params = list(net.parameters())
for i in params:
print(i.grad)
В приведенном выше коде я печатаю грады с / без сброса ввода x.Градиент зависит от x, и его сброс не имеет значения.Поэтому я думаю, что граф сохраняет информацию для обратной операции.