Одним из решений является использование более стабильных вычислений. Обратите внимание, что log(1 + exp(x))
приблизительно равно x
, когда x
достаточно велико. Интуитивно это можно наблюдать, заметив, что, например, exp(50)
составляет приблизительно 5.18e+21
, для которого добавление 1
не будет иметь эффекта при использовании 32-битной арифметики с плавающей точкой c, как это делает PyTorch. Дальнейшая проверка с использованием калькулятора произвольной точности показывает, что ошибка в этом приближении при 50 находится далеко за пределами максимальной 32-разрядной точности с плавающей запятой (которая составляет около 7 десятичных цифр).
Использование этого Для информации мы можем реализовать простую кусочную функцию в PyTorch, для которой мы используем log1p(exp(x))
для значений меньше 50 и x
для значений больше 50. Также обратите внимание, что эта функция совместима с autograd
def log1pexp(x):
# more stable version of log(1 + exp(x))
y = torch.zeros(x.shape, device=x.device)
mask_low = x < 50
mask_high = torch.logical_not(mask_low)
y[mask_low] = torch.log1p(torch.exp(torch.masked_select(x, mask_low)))
y[mask_high] = torch.masked_select(x, mask_high)
return y
Это дает нам большую часть пути к решению, поскольку вы действительно хотите оценить torch.log10(1+torch.exp((Pss-k*Pvv)*s))/s
Теперь мы можем использовать нашу новую функцию log1pexp
для вычисления этого выражения, не беспокоясь о бесконечности
(log1pexp((Pss - k*Pvv)*s) / math.log(10)) / s
и учитывайте преобразование из натурального бревна в бревно-10 путем деления на log(10)
.