Я сам вычисляю градиент вручную, используя формулы для этой очень простой линейной сети с потерей MSE.
Затем я сравниваю градиент, который вычисляется PyTorch, и использую функцию allclose
из PyTorch, чтобы проверить правильность вычисления градиентов PyTorch (т. Е. Относительная разница между вычисляемыми вручную градиентами и pytorch маладостаточно).
Все тесты должны пройти, так как формулы верны.Но для некоторых семян это просто не так.
Так что, очевидно, PyTorch не делает ничего плохого, но поскольку формулы верны, это должно быть связано с некоторыми проблемами числовой нестабильности в формулах.
import torch
class Network(torch.nn.Module):
def __init__(self):
super(Network, self).__init__()
self.linear = torch.nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
loss = torch.nn.MSELoss()
for i in range(0, 1000):
torch.manual_seed(i)
X = torch.randn(100, 10)
y = torch.randn(100, 1)
model=Network()
model.train()
optimizer=torch.optim.SGD(model.parameters(),lr=1.)
optimizer.zero_grad()
output = loss(model(X), y)
output.backward()
torch_grads=[]
for p in model.parameters():
torch_grads.append(p.grad.detach().data)
#df/dW = (-2X.T*y+2*X.T*b+2*X.T*X*W)/nsamples
#df/db = (2*b-2*y+2*W.T*X.T).mean() (the mean comes from implicit broadcasting of b)
theory_grad_w = (-2 * torch.matmul(X.t(), y)
+2 * torch.matmul(torch.t(X), torch.ones((X.shape[0], 1)))* list(model.parameters())[1]
+2 * torch.matmul(torch.matmul(X.t(), X), list(model.parameters())[0].t())
) / float(X.shape[0])
theory_grad_w = theory_grad_w.t()
theory_grad_b = torch.mean(2 * list(model.parameters())[1]- 2 * y+ 2 * torch.matmul((list(model.parameters())[0]), torch.t(X)))
theory_grads = [theory_grad_w, theory_grad_b]
b=all([torch.allclose(u, d) for u, d in zip(torch_grads, theory_grads)])
if not(b):
print("i=%s, pass=%s"%(i, b))
Каковы источники наблюдаемой числовой нестабильности и как с ними бороться, чтобы тесты проходили все время.Это просто вопрос упорядочения операций по-другому?