Как работает функция потерь MSE в pytorch? - PullRequest
0 голосов
/ 04 декабря 2018

В следующем коде (извлеченном из SentEval ) определена структура нейронной сети, которая отображает 1024 действительных числа на 5 выходных предсказаний.Проблема состоит в том, чтобы оценить взаимосвязь между двумя предложениями (каждое представлено 512 признаками).Родство - это число в [1,5].Я думаю, что если бы числа относительности обучения были в {1,2,3,4,5}, cross entropy была бы лучшей функцией потерь, но так как в обучающем наборе мы имеем реальные числа соотношения в [1,5],MSE используется в качестве функции потерь.

Вопрос: Поскольку для каждого входа сеть выдает 5 вероятностных чисел, как MSE вычисляется между действительным числом и 5 вероятностными числами?

from torch import nn

inputdim = 1024
nclasses = 5
model = nn.Sequential(
            nn.Linear(inputdim, nclasses),
            nn.Softmax(dim=-1),
        )
loss_fn = nn.MSELoss()
...