Как вычислить градиент степенной функции по экспоненте в PyTorch? - PullRequest
0 голосов
/ 22 июня 2019

Я пытаюсь вычислить градиент

out = x.sign()*torch.pow(x.abs(), alpha)

относительно альфа.

Я попробовал следующее:

class Power(nn.Module):
  def __init__(self, alpha=2.):
    super(Power, self).__init__()
    self.alpha = nn.Parameter(torch.tensor(alpha))

  def forward(self, x):
    return x.sign()*torch.abs(x)**self.alpha

но этот класс продолжает давать мне nan на тренировке моей сети. Я ожидаю увидеть что-то вроде grad=out*torch.log(x), но не могу добраться до него. Этот код, например, ничего не возвращает:

alpha_rooting = Power()
x = torch.randn((1), device='cpu', dtype=torch.float)
out = (alpha_rooting(x)).sum()
out.backward()
print(out.grad)

Я пытаюсь использовать autograd для этого тоже не случайно. Как мне решить эту проблему? Спасибо.

1 Ответ

0 голосов
/ 22 июня 2019

Класс Power(), который вы написали, работает, как и ожидалось. Существует проблема в том, как вы используете это на самом деле. Градиенты хранятся в .grad этой переменной, а не в переменной out, как вы использовали выше. Вы можете изменить код, как показано ниже.

alpha_rooting = Power()
x = torch.randn((1), device='cpu', dtype=torch.float)
out = (alpha_rooting(x)).sum()

# compute gradients of all parameters with respect to out (dout/dparam)
out.backward()
# print gradient of alpha
# Note that gradients are store in .grad of parameter not out variable
print(alpha_rooting.alpha.grad)

# compare if it is approximately correct to exact grads
err = (alpha_rooting.alpha.grad - out*torch.log(x))**2 
if (err <1e-8):
    print("Gradients are correct")
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...