Как использовать оптимизатор в прямом проходе в PyTorch - PullRequest
0 голосов
/ 27 февраля 2020

Я хочу использовать оптимизатор в прямом проходе пользовательской функции, но он не работает. Мой код выглядит следующим образом:

class MyFct(Function):

   @staticmethod
   def forward(ctx, *args):
       input, weight, bias = args[0], args[1], args[2]

       y = torch.tensor([[0]], dtype=torch.float, requires_grad=True) #initial guess
       loss_fn = lambda y_star: (input + weight - y_star)**2

       learning_rate = 1e-4
       optimizer = torch.optim.Adam([y], lr=learning_rate)
       for t in range(5000):
           y_star = y
           print(y_star)
           loss = loss_fn(y_star)
           if t % 100 == 99:
               print(t, loss.item())
           optimizer.zero_grad()
           loss.backward()
           optimizer.step() 

       return y_star

И это мои тестовые входы:

x = torch.tensor([[2]], dtype=torch.float, requires_grad=True)
w = torch.tensor([[2]], dtype=torch.float, requires_grad=True)
y = torch.tensor([[6]], dtype=torch.float)

fct= MyFct.apply
y_hat = fct(x, w, None)

Я всегда получаю RuntimeError: элемент 0 тензоров не требует grad и не требует есть grad_fn .

Кроме того, я проверил оптимизацию вне форварда, и она работает, так что я думаю, что-то с контекстом? Согласно документации «Аргументы Tensor, которые отслеживают историю (т. Е. С require_grad = True), будут преобразованы в аргументы, которые не отслеживают историю до вызова, и их использование будет зарегистрировано на графике», см. https://pytorch.org/docs/stable/notes/extending.html. Это проблема? Есть ли способ обойти это?

Я новичок в PyTorch, и мне интересно, что я пропускаю. Любая помощь и объяснение приветствуется.

1 Ответ

0 голосов
/ 27 февраля 2020

Я думаю, что нашел ответ здесь: https://github.com/pytorch/pytorch/issues/8847, т.е. мне нужно обернуть опримизацию с помощью with torch.enable_grad():.

Однако я до сих пор не понимаю, зачем это нужно преобразовать оригинальные Тензорные в те, которые не отслеживают историю в forward ().

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...