Что такое убыток в PyTorch и как он рассчитывается - PullRequest
0 голосов
/ 08 апреля 2020

Я изучил это руководство в документации PyTorch для понимания Transfer Learning. Была одна строка, которую я не смог понять.

После того, как потери рассчитаны с использованием loss = criterion(outputs, labels), текущие потери рассчитаны с использованием running_loss += loss.item() * inputs.size(0) и, наконец, потери эпохи рассчитаны с использованием running_loss / dataset_sizes[phase].

Разве не loss.item() должно быть для всей мини-партии (пожалуйста, поправьте меня, если я ошибаюсь). то есть, если batch_size равно 4, loss.item() даст потерю для всего набора из 4 изображений. Если это правда, почему loss.item() умножается на inputs.size(0) при расчете running_loss? Разве этот шаг не является дополнительным умножением в этом случае?

Любая помощь будет принята с благодарностью. Спасибо!

Ответы [ 2 ]

0 голосов
/ 08 апреля 2020

, если batch_size равен 4, loss.item () даст потерю для всего набора из 4 изображений

Это зависит от того, как вычисляется loss. Помните, loss - тензор, как и любой другой тензор. В общем случае API PyTorch возвращают средние потери по умолчанию

"Потери усредняются по наблюдениям для каждой мини-партии."

t.item() для тензор t просто конвертирует его в значение по умолчанию python float32.

Что еще более важно, если вы новичок в PyTorch, вам может быть полезно знать, что мы используем t.item() для поддержания текущих потерь вместо t, потому что тензоры PyTorch хранят историю своих значений, которая может перегрузить ваш графический процессор очень скоро.

0 голосов
/ 08 апреля 2020

Это потому, что потери, заданные CrossEntropy или другими функциями потерь, делятся на количество элементов, то есть параметр уменьшения по умолчанию составляет mean.

факел .nn.CrossEntropyLoss (вес = нет, size_average = нет, ignore_index = -100, уменьшить = нет, сокращение = «среднее»)

Следовательно, loss.item() содержит потерю всей мини-партии , но делится на размер партии. Вот почему loss.item() умножается на размер партии, заданный inputs.size(0), при расчете running_loss.

...