grad_outputs в torch.autograd.grad (CrossEntropyLoss) - PullRequest
0 голосов
/ 13 января 2019

Я пытаюсь получить d(loss)/d(input). Я знаю, что у меня есть 2 варианта.

Первый вариант:

    loss.backward()
    dlossdx = x.grad.data

Второй вариант:

    # criterion = nn.CrossEntropyLoss(reduce=False)
    # loss = criterion(y_hat, labels)     
    # No need to call backward. 
    dlossdx = torch.autograd.grad(outputs = loss,
                                  inputs = x,
                                  grad_outputs = ? )

Мой вопрос: если я использую кросс-энтропийную потерю, что я должен передать как grad_outputs во втором варианте?

Должен ли я поставить d(CE)/d(y_hat)? Поскольку кросцентропия Pytorch содержит softmax, мне потребуется предварительно вычислить производную softmax с использованием дельты Кронекера.

Или я могу поставить d(CE)/d(CE), что torch.ones_like?

Концептуальный ответ в порядке.

1 Ответ

0 голосов
/ 13 января 2019

Давайте попробуем понять, как работают оба варианта.

Мы будем использовать эту настройку

import torch 
import torch.nn as nn
import numpy as np 
x = torch.rand((64,10), requires_grad=True)
net = nn.Sequential(nn.Linear(10,10))
labels = torch.tensor(np.random.choice(10, size=64)).long()
criterion = nn.CrossEntropyLoss()

Первый вариант

loss = criterion(net(x), labels)
loss.backward(retain_graph=True)
dloss_dx = x.grad

Обратите внимание, что вы не передали параметры градиенту, потому что потеря является скалярной величиной, если вы вычисляете потерю как вектор, тогда вы должны передать

Второй вариант

dloss_dx2 = torch.autograd.grad(loss, x)

Это вернет кортеж, и вы можете использовать первый элемент в качестве градиента х.

Обратите внимание, что torch.autograd.grad возвращает сумму dout / dx, если вы передаете несколько выходов в виде кортежей. Но так как потеря скалярна, вам не нужно передавать grad_outputs, так как по умолчанию она будет считаться равной единице.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...