pytorch - градиенты, не рассчитанные для параметров - PullRequest
3 голосов
/ 16 октября 2019
a = torch.nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float, device=device))
b = torch.nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float, device=device))
c = a + 1
d = torch.nn.Parameter(c, requires_grad=True,)
for epoch in range(n_epochs):
    yhat = d + b * x_train_tensor
    error = y_train_tensor - yhat
    loss = (error ** 2).mean()
    loss.backward()
    print(a.grad)
    print(b.grad)
    print(c.grad)
    print(d.grad)

Распечатывает

None
tensor([-0.8707])
None
tensor([-1.1125])

Как узнать градиент для a и c? переменная d должна оставаться параметром

1 Ответ

1 голос
/ 17 октября 2019

В основном, когда вы создаете новый тензор, например torch.nn.Parameter() или torch.tensor(), вы создаете листовой узел тензор.

И когда вы делаете что-то вроде c=a+1, c будет промежуточным узлом . Вы можете print(c.is_leaf) проверить, является ли тензор листовым узлом или нет. Pytorch не будет вычислять градиент промежуточного узла по умолчанию.

В вашем фрагменте кода a, b, d - это тензор конечных узлов, а c - промежуточный узел. c.grad будет None, так как pytorch не вычисляет градиент для промежуточного узла. a изолируется от графика, когда вы звоните loss.backword(). Вот почему a.grad также None.

Если вы измените код на этот

a = torch.nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float, device=device))
b = torch.nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float, device=device))
c = a + 1
d = c
for epoch in range(n_epochs):
    yhat = d + b * x_train_tensor
    error = y_train_tensor - yhat
    loss = (error ** 2).mean()
    loss.backward()
    print(a.grad) # Not None
    print(b.grad) # Not None
    print(c.grad) # None
    print(d.grad) # None

Вы обнаружите, что a и b имеют градиенты, но c.grad и d.grad равны None, потому что онипромежуточный узел.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...