torch.einsum 'RuntimeError: несоответствие размеров для операнда 0: уравнение 4 тензор 2' - PullRequest
0 голосов
/ 20 июня 2020

Я пытаюсь вручную вычислить градиент матрицы, и я могу сделать это, используя numpy, но я не знаю, как сделать то же самое в pytorch. уравнение в NumPy - это

def grad(A, W0, W1, X):
    dim = A.shape
    assert len(dim) == 2
    A_rows = dim[0]
    A_cols = dim[1]    
    gradient = (np.einsum('ik, jl', np.eye(A_cols, A_rows), (((A).dot(X)).dot(W0)).dot(W1).T) + np.einsum('ik, jl', A, ((X).dot(W0)).dot(W1).T))
    return gradient

Я написал функцию в pytorch, но это дает мне сообщение об ошибке: «RuntimeError: несоответствие размеров для операнда 0: тензор 2 уравнения 4»

функция, которую я написал с помощью pytorch, - это

def torch_grad(A, W0, W1, X):
    dim = A.shape
    A_rows = dim[0]
    A_cols = dim[1]
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)
    print(torch.eye(A_cols, A_rows).shape, torch.t(AXW0W1).shape)
    e1 = torch.einsum('ik jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
    e2 = torch.einsum('ik, jl', A, torch.t(XW0W1))
    return e1 + e2

Я был бы признателен, если бы кто-нибудь показал мне, как реализовать код numpy в pytorch. Спасибо!

1 Ответ

0 голосов
/ 20 июня 2020

Вам не хватает запятой в первом вызове torch.einsum.

e1 = torch.einsum('ik, jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))

Помимо опечатки, градиенты по отношению к A вычисляются не так, и это не удается, когда A и AX имеют разные размеры, которые в противном случае были бы действительны для прямого прохода. Для e1 это должно быть матричное умножение, возможно, это было вашим намерением, и в этом случае torch.einsum должно быть 'ik, kl', но это просто слишком сложный способ выполнить матричное умножение, а использование torch.mm - это проще и эффективнее. И e2 не участвует ни в каких вычислениях, которые были выполнены относительно A, поэтому он не является частью градиентов.

def torch_grad(A, W0, W1, X):
    # Forward
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)

    # Backward / Gradients
    rows, cols = AXW0W1.size()
    grad_AX = torch.mm(torch.eye(rows, cols), W0W1.t())
    grad_A = torch.mm(grad_AX, X.t())
    return grad_A

# Autograd version to verify that the gradients are correct
def torch_autograd(A, W0, W1, X):
    # Forward
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)
    
    # Backward / Gradients
    rows, cols = AXW0W1.size()
    AXW0W1.backward(torch.eye(rows, cols))
    return A.grad

# requires_grad=True for the autograd version to track
# gradients with respect to A
A = torch.randn(3, 4, requires_grad=True)
X = torch.randn(4, 5)
W0 = torch.randn(5, 6)
W1 = torch.randn(6, 5)

grad_result = torch_grad(A, W0, W1, X)
autograd_result = torch_autograd(A, W0, W1, X)

torch.equal(grad_result, autograd_result) # => True
...