Предположим, у меня есть слой layer
в модуле горелки, и я использую его дважды или более раз за один forward
шаг, таким образом, чтобы результат, полученный этим layer
, позже снова вводился в то же самое layer
. Может ли Pytorch autograd
правильно вычислить градуировку весов этого слоя?
Вот что я имею в виду:
import torch
import torch.nn as nn
import torch.nn.functional as F
class net(nn.Module):
def __init__(self,in_dim,out_dim):
super(net,self).__init__()
self.layer = nn.Linear(in_dim,out_dim,bias=False)
def forward(self,x):
x = self.layer(x)
x = self.layer(x)
return x
input_x = torch.tensor([10.])
label = torch.tensor([5.])
n = net(1,1)
loss_fn = nn.MSELoss()
out = n(input_x)
loss = loss_fn(out,label)
n.zero_grad()
loss.backward()
for param in n.parameters():
w = param.item()
g = param.grad
print('Input = %.4f; label = %.4f'%(input_x,label))
print('Weight = %.4f; output = %.4f'%(w,out))
print('Gradient w.r.t. the weight is %.4f'%(g))
print('And it should be %.4f'%(4*(w**2*input_x-label)*w*input_x))
И вывод (может быть другим на вашем компьютере, если начальное значение веса отличается):
Input = 10.0000; label = 5.0000
Weight = 0.9472; output = 8.9717
Gradient w.r.t. the weight is 150.4767
And it should be 150.4766
В этом примере я определил модуль только с одним линейным слоем (in_dim=out_dim=1
и без смещения). w
- вес этого слоя; input_x
- входное значение; label
- желаемое значение. Поскольку в качестве MSE выбрана потеря, формула потери:
((w^2)*input_x-label)^2
Вычисления вручную, мы имеем
dw/dx = 2*((w^2)*input_x-label)*(2*w*input_x)
Вывод моего примера выше показывает, что autograd
дает тот же результат, что и вычисленный вручную, давая мне повод полагать, что он может работать в этом случае. Но в реальном приложении слой может иметь входы и выходы более высоких измерений, нелинейную функцию активации после него и нейронную сеть может иметь несколько слоев.
Я хочу спросить: могу ли я доверять autograd
справиться с такой ситуацией, но намного сложнее, чем в моем примере? Как это работает, когда слой вызывается итеративно?