Pytorch: Как функция .grad () возвращает результат? - PullRequest
0 голосов
/ 08 июня 2018

Я пытаюсь понять функцию grad () в python, я знаю о обратном распространении, но у меня есть некоторые сомнения в результате функции .grad ().

Так что, если у меня очень простая сеть, скажем, с одним входоми один единственный вес:

import torch
from torch.autograd import Variable
from torch import FloatTensor


a_tensor=Variable(FloatTensor([1]))
weight=Variable(FloatTensor([1]),requires_grad=True)

Теперь я запускаю это в ячейке ipython:

net_out=a_tensor*weight
loss=5-net_out
loss.backward()
print("atensor",a_tensor)
print('weight',weight)
print('net_out',net_out)
print('loss',loss)
print(weight.grad)

Во время первого запуска возвращается:

atensor tensor([ 1.])
weight tensor([ 1.])
net_out tensor([ 1.])
loss tensor([ 4.])
tensor([-1.])

Что правильнопотому что если я прав, то вычисление уравнения градиента будет здесь:

Теперь netout / w будет (w * a) по отношению к w ==> 1 * a
И убыток / netout (5-netout) по отношению к netout ==> (0-1)

Что было бы 1 * a * -1 ==> -1

Но проблема в том, что если я снова нажму ту же ячейку, ничего не меняя, то получу град -2, -3, -4... и т. д.

atensor tensor([ 1.])
weight tensor([ 1.])
net_out tensor([ 1.])
loss tensor([ 4.])
tensor([-2.])

следующий запуск:

atensor tensor([ 1.])
weight tensor([ 1.])
net_out tensor([ 1.])
loss tensor([ 4.])
tensor([-3.])

и т. д.

Я не понимаю, что там происходит, почему и как значение gradувеличение

1 Ответ

0 голосов
/ 08 июня 2018

Это потому, что вы не обнуляете градиенты.loss.backward() накапливает градиенты - это добавляет градиенты к существующим.Если вы не обнуляете градиент, запускайте loss.backward() снова и снова, просто продолжайте добавлять градиенты друг к другу.То, что вы хотите сделать, это обнулить градиенты после каждого шага, и вы увидите, что градиенты рассчитаны правильно.

Если вы построили сеть net (которая должна быть объектом класса nn.Module), вы можете обнулить градиенты, просто вызвав net.zero_grad().Если вы не построили net (или torch.optim объект), вам придется самостоятельно обнулять градиенты вручную.

Используйте там метод weight.grad.data.zero_().

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...