Простая потеря L1 в PyTorch - PullRequest
1 голос
/ 16 июня 2020

Я хочу рассчитать потерю L1 в нейронной сети, я наткнулся на этот пример на https://discuss.pytorch.org/t/simple-l2-regularization/139/2, но в этом коде есть некоторые ошибки.

Это действительно способ вычисления L1 Loss в NN или есть способ попроще?

l1_crit = nn.L1Loss()
reg_loss = 0
for param in model.parameters():
    reg_loss += l1_crit(param)

factor = 0.0005
loss += factor * reg_loss

Является ли это каким-либо образом эквивалентом простого выполнения:

loss = torch.nn.L1Loss()

Полагаю, что нет, потому что я не передаю никаких сетевых параметров. Просто проверяю, существует ли для этого функция.

1 Ответ

0 голосов
/ 16 июня 2020

Если я хорошо понимаю, вы хотите вычислить потерю L1 вашей модели (как вы говорите в начале). Однако я думаю, что вы могли запутаться в обсуждении на форуме pytorch.

Насколько я понимаю, на форумах Pytorch и в опубликованном вами коде автор пытается нормализовать веса сети с помощью регуляризации L1. Таким образом, он пытается обеспечить, чтобы значения весов попадали в разумный диапазон (не слишком большой, не слишком маленький). Это нормализация весов с использованием нормализации L1 (поэтому используется model.parameters()). Нормализация принимает значение на входе и производит нормализованное значение на выходе. Проверьте это для нормализации весов: https://pytorch.org/docs/master/generated/torch.nn.utils.weight_norm.html

С другой стороны, L1 Loss - это просто способ определить, как 2 значения отличаются друг от друга, поэтому "потеря" просто мера этой разницы. В случае потери L1 эта ошибка вычисляется с помощью средней абсолютной ошибки loss = |x-y|, где x и y - значения для сравнения. Таким образом, вычисление ошибок принимает 2 значения на входе и производит значение на выходе. Проверьте это для вычисления потерь: https://pytorch.org/docs/master/generated/torch.nn.L1Loss.html

Чтобы ответить на ваш вопрос: нет, приведенные выше фрагменты не эквивалентны, поскольку первый пытается выполнить нормализацию весов, а второй - вы пытаются подсчитать убыток. Это будет вычисление потерь с некоторым контекстом:

sample, target = dataset[i]
target_predicted = model(sample)
loss = torch.nn.L1Loss()
loss_value = loss(target, target_predicted)
...