Вам не хватает запятой в первом вызове torch.einsum
.
e1 = torch.einsum('ik, jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
Помимо опечатки, градиенты по отношению к A
вычисляются не так, и это не удается, когда A
и AX
имеют разные размеры, которые в противном случае были бы действительны для прямого прохода. Для e1
это должно быть матричное умножение, возможно, это было вашим намерением, и в этом случае torch.einsum
должно быть 'ik, kl'
, но это просто слишком сложный способ выполнить матричное умножение, а использование torch.mm
- это проще и эффективнее. И e2
не участвует ни в каких вычислениях, которые были выполнены относительно A
, поэтому он не является частью градиентов.
def torch_grad(A, W0, W1, X):
# Forward
W0W1 = torch.mm(W0, W1)
AX = torch.mm(A, X)
AXW0W1 = torch.mm(AX, W0W1)
XW0W1 = torch.mm(X, W0W1)
# Backward / Gradients
rows, cols = AXW0W1.size()
grad_AX = torch.mm(torch.eye(rows, cols), W0W1.t())
grad_A = torch.mm(grad_AX, X.t())
return grad_A
# Autograd version to verify that the gradients are correct
def torch_autograd(A, W0, W1, X):
# Forward
W0W1 = torch.mm(W0, W1)
AX = torch.mm(A, X)
AXW0W1 = torch.mm(AX, W0W1)
XW0W1 = torch.mm(X, W0W1)
# Backward / Gradients
rows, cols = AXW0W1.size()
AXW0W1.backward(torch.eye(rows, cols))
return A.grad
# requires_grad=True for the autograd version to track
# gradients with respect to A
A = torch.randn(3, 4, requires_grad=True)
X = torch.randn(4, 5)
W0 = torch.randn(5, 6)
W1 = torch.randn(6, 5)
grad_result = torch_grad(A, W0, W1, X)
autograd_result = torch_autograd(A, W0, W1, X)
torch.equal(grad_result, autograd_result) # => True