Вычислить градиент между скаляром и вектором в PyTorch - PullRequest
0 голосов
/ 17 октября 2019

Я пытаюсь скопировать код, написанный с помощью Theano, в PyTorch. В коде автор вычисляет градиент, используя

import theano.tensor as T    
gparams = T.grad(cost, params)

, а форма gparams равна (256, 240)

Я пытался использовать backward(), но, похоже, он невернуть что-нибудь. Есть ли эквивалент в grad в PyTorch?

Предположим, это мой ввод,

import torch
from torch.autograd import Variable 
cost = torch.tensor(1.6019)
params = Variable(torch.rand(1, 73, 240))

1 Ответ

0 голосов
/ 17 октября 2019

cost должен быть результатом операции, включающей params. Вы не можете вычислить градиент, просто зная значения двух тензоров. Вам также необходимо знать отношения. Вот почему Pytorch создает граф вычислений, когда вы выполняете тензорные операции. Например, допустим, что отношение равно

cost = torch.sum(params)

, тогда мы ожидаем, что градиент cost относительно params будет вектором единиц независимо от значения params.

Это можно рассчитать следующим образом. Обратите внимание, что вам нужно добавить флаг requires_grad, чтобы указать pytorch, что вы хотите backward обновить градиент при вызове.

# Initialize independent variable. Make sure to set requires_grad=true.
params = torch.tensor((1, 73, 240), requires_grad=True)

# Compute cost, this implicitly builds a computation graph which records
# how cost was computed with respect to params.
cost = torch.sum(params)

# Zero the gradient of params in case it already has something in it.
# This step is optional in this example but good to do in practice to
# ensure you're not adding gradients to existing gradients.
if params.grad is not None:
    params.grad.zero_()

# Perform back propagation. This is where the gradient is actually
# computed. It also resets the computation graph.
cost.backward()

# The gradient of params w.r.t to cost is now stored in params.grad.
print(params.grad)

Результат:

tensor([1., 1., 1.])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...