Да, в этом случае он действует так же, как torch.nn.MSELoss
, и в целом он называется Huber Loss
.
В силу своей природы threshold
не имеет особого смысла, давайте посмотрим на пример, почему это так:
Как это работает
Давайте сравним ошибки, превышающие 1.0
в случае MSELoss
и SmoothL1Loss
. Предположим, что наша абсолютная ошибка (|f(x) - y|
) равна 10
. MSELoss
даст ему значение 100
(или 50
в случае реализации pytorch
), в то время как SmoothL1Loss
даст только это значение 10
, следовательно, модель не будет наказываться sh, так много для больших ошибок.
В случае значения ниже 1.0
SmoothL1Loss
наказывает модель меньше L1Loss
. Например, 0.5
станет 0.5*0.5
, поэтому 0.25
для Хьюбера и 0.5
для L1Loss
.
Это не "лучшее из обоих миров", это зависит от того, что вы ищете. Mean Squared Error
- усиливает большие ошибки и преуменьшает мелкие, L1Loss
дает ошибкам "равный" вес, скажем.
Пользовательская функция потерь
Хотя обычно это не делается, вы можете использовать любые потери Функция, которую вы хотите, в зависимости от вашей цели (порог здесь не имеет смысла). Если вы хотите, чтобы меньшие ошибки были более серьезными, вы можете, например, сделать что-то вроде этого:
import torch
def fancy_squared_loss(y_true, y_pred):
return torch.mean(torch.sqrt(torch.abs(y_true - y_pred)))
Для значения 0.2
вы получите ~0.447
, для 0.5
~0.7
и т. Д. на. Поэкспериментируйте и проверьте, существуют ли какие-либо конкретные c функции потерь для поставленной задачи, хотя я думаю, что вряд ли эти эксперименты дадут вам существенное повышение по сравнению с L1Loss
, если таковые имеются.
Пользовательский порог
Если вы действительно хотите установить пользовательский порог для MSELoss
и L1Loss
, вы можете реализовать его самостоятельно, хотя:
import torch
class CustomLoss:
def __init__(self, threshold: float = 0.5):
self.threshold = threshold
def __call__(self, predicted, true):
errors = torch.abs(predicted - true)
mask = errors < self.threshold
return (0.5 * mask * (errors ** 2)) + ~mask * errors
Все ниже threshold
получит MSELoss
, в то время как все выше будет L1Loss
.