Как заменить инфы, чтобы избежать нан градиентов в PyTorch - PullRequest
1 голос
/ 19 июня 2019

Мне нужно вычислить log(1 + exp(x)), а затем использовать автоматическое дифференцирование на нем.Но для слишком большого x он выводит inf из-за возведения в степень:

>>> x = torch.tensor([0., 1., 100.], requires_grad=True)
>>> x.exp().log1p()
tensor([0.6931, 1.3133,    inf], grad_fn=<Log1PBackward>)

Поскольку log(1 + exp(x)) ≈ x для большого x, я думал, что смогу заменить infs на xиспользуя torch.where.Но при этом я все равно получаю nan за градиент слишком больших значений.Знаете ли вы, почему это происходит и есть ли другой способ заставить это работать?

>>> exp = x.exp()
>>> y = x.where(torch.isinf(exp), exp.log1p())  # Replace infs with x
>>> y  # No infs
tensor([  0.6931,   1.3133, 100.0000], grad_fn=<SWhereBackward>)
>>> y.sum().backward()  # Automatic differentiation
>>> x.grad  # Why is there a nan and how can I get rid of it?
tensor([0.5000, 0.7311,    nan])

Ответы [ 2 ]

1 голос
/ 20 июня 2019

Но при слишком большом x выводит inf из-за возведения в степень

Вот почему x никогда не должно быть слишком большим. Идеально должно быть в диапазоне [-1, 1]. Если это не так, вы должны нормализовать свои входные данные.

0 голосов
/ 19 июня 2019

Обходной путь, который я нашел, состоит в том, чтобы вручную реализовать функцию Log1PlusExp с ее обратным аналогом.Тем не менее, это не объясняет плохое поведение torch.where в вопросе.

>>> class Log1PlusExp(torch.autograd.Function):
...     """Implementation of x ↦ log(1 + exp(x))."""
...     @staticmethod
...     def forward(ctx, x):
...         exp = x.exp()
...         ctx.save_for_backward(x)
...         return x.where(torch.isinf(exp), exp.log1p())
...     @staticmethod
...     def backward(ctx, grad_output):
...         x, = ctx.saved_tensors
...         return grad_output / (1 + (-x).exp())
... 
>>> log_1_plus_exp = Log1PlusExp.apply
>>> x = torch.tensor([0., 1., 100.], requires_grad=True)
>>> log_1_plus_exp(x)  # No infs
tensor([  0.6931,   1.3133, 100.0000], grad_fn=<Log1PlusExpBackward>)
>>> log_1_plus_exp(x).sum().backward()
>>> x.grad  # And no nans!
tensor([0.5000, 0.7311, 1.0000])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...