Как предотвратить инф при работе с экспоненциальной - PullRequest
1 голос
/ 28 марта 2020

Я пытаюсь создать функцию в сети с настраиваемыми параметрами. В моей функции у меня есть экспонента, которая для больших тензорных значений уходит в бесконечность. Каков наилучший способ избежать этого?

Функция выглядит следующим образом:

step1 = Pss-(k*Pvv)
step2 = step1*s
step3 = torch.exp(step2)
step4 = torch.log10(1+step3)
step5 = step4/s

#or equivalently
# train_curve = torch.log(1+torch.exp((Pss-k*Pvv)*s))/s

Если это облегчает понимание, функция basei c будет log10 (1 + e ^ (x-const) * 10) / 10. Экспонента внутри журнала становится слишком большой и переходит в инф.

Я думаю, что мне, возможно, придется нормализовать мой тензор x, и это будет означать также нормализацию констант и остальной части функции. Кто-нибудь мог бы подумать о том, как лучше всего go об этом?

Большое спасибо.

1 Ответ

1 голос
/ 29 марта 2020

Одним из решений является использование более стабильных вычислений. Обратите внимание, что log(1 + exp(x)) приблизительно равно x, когда x достаточно велико. Интуитивно это можно наблюдать, заметив, что, например, exp(50) составляет приблизительно 5.18e+21, для которого добавление 1 не будет иметь эффекта при использовании 32-битной арифметики с плавающей точкой c, как это делает PyTorch. Дальнейшая проверка с использованием калькулятора произвольной точности показывает, что ошибка в этом приближении при 50 находится далеко за пределами максимальной 32-разрядной точности с плавающей запятой (которая составляет около 7 десятичных цифр).

Использование этого Для информации мы можем реализовать простую кусочную функцию в PyTorch, для которой мы используем log1p(exp(x)) для значений меньше 50 и x для значений больше 50. Также обратите внимание, что эта функция совместима с autograd

def log1pexp(x):
    # more stable version of log(1 + exp(x))
    y = torch.zeros(x.shape, device=x.device)
    mask_low = x < 50
    mask_high = torch.logical_not(mask_low)
    y[mask_low] = torch.log1p(torch.exp(torch.masked_select(x, mask_low)))
    y[mask_high] = torch.masked_select(x, mask_high)
    return y

Это дает нам большую часть пути к решению, поскольку вы действительно хотите оценить torch.log10(1+torch.exp((Pss-k*Pvv)*s))/s

Теперь мы можем использовать нашу новую функцию log1pexp для вычисления этого выражения, не беспокоясь о бесконечности

(log1pexp((Pss - k*Pvv)*s) / math.log(10)) / s

и учитывайте преобразование из натурального бревна в бревно-10 путем деления на log(10).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...