Я хочу использовать оптимизатор в прямом проходе пользовательской функции, но он не работает. Мой код выглядит следующим образом:
class MyFct(Function):
@staticmethod
def forward(ctx, *args):
input, weight, bias = args[0], args[1], args[2]
y = torch.tensor([[0]], dtype=torch.float, requires_grad=True) #initial guess
loss_fn = lambda y_star: (input + weight - y_star)**2
learning_rate = 1e-4
optimizer = torch.optim.Adam([y], lr=learning_rate)
for t in range(5000):
y_star = y
print(y_star)
loss = loss_fn(y_star)
if t % 100 == 99:
print(t, loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
return y_star
И это мои тестовые входы:
x = torch.tensor([[2]], dtype=torch.float, requires_grad=True)
w = torch.tensor([[2]], dtype=torch.float, requires_grad=True)
y = torch.tensor([[6]], dtype=torch.float)
fct= MyFct.apply
y_hat = fct(x, w, None)
Я всегда получаю RuntimeError: элемент 0 тензоров не требует grad и не требует есть grad_fn .
Кроме того, я проверил оптимизацию вне форварда, и она работает, так что я думаю, что-то с контекстом? Согласно документации «Аргументы Tensor, которые отслеживают историю (т. Е. С require_grad = True), будут преобразованы в аргументы, которые не отслеживают историю до вызова, и их использование будет зарегистрировано на графике», см. https://pytorch.org/docs/stable/notes/extending.html. Это проблема? Есть ли способ обойти это?
Я новичок в PyTorch, и мне интересно, что я пропускаю. Любая помощь и объяснение приветствуется.