Функция потери и ее входы для двоичной классификации PyTorch - PullRequest
0 голосов
/ 05 декабря 2018

Я пытаюсь написать нейронную сеть для двоичной классификации в PyTorch, и меня смущает функция потерь.

Я вижу, что BCELoss - это общая функция, специально предназначенная для двоичной классификации.Я также вижу, что выходной слой из N выходов для N возможных классов является стандартным для общей классификации.Однако для двоичной классификации кажется, что это может быть 1 или 2 выхода.

Итак, я должен иметь 2 выхода (1 для каждой метки) и затем конвертировать мои 0/1 обучающие метки в [1,0] и [0,1] или использовать что-то вроде сигмоида для вывода с одной переменной?

Вот соответствующие фрагменты кода, которые вы можете увидеть:

self.outputs = nn.Linear(NETWORK_WIDTH, 2) # 1 or 2 dimensions?


def forward(self, x):
  # other layers omitted
  x = self.outputs(x)           
  return F.log_softmax(x)  # <<< softmax over multiple vars, sigmoid over one, or other?

criterion = nn.BCELoss() # <<< Is this the right function?

net_out = net(data)
loss = criterion(net_out, target) # <<< Should target be an integer label or 1-hot vector?

Спасибозаранее.

1 Ответ

0 голосов
/ 05 декабря 2018

Для двоичных выходов вы можете использовать 1 выходной единицы, поэтому:

self.outputs = nn.Linear(NETWORK_WIDTH, 1)

Затем вы используете sigmoid активациячтобы отобразить значения вашей единицы вывода в диапазоне от 0 до 1 (конечно, вам необходимоданные тренировки тоже таким образом) :

def forward(self, x):
    # other layers omitted
    x = self.outputs(x)           
    return F.sigmoid(x)  

Наконец, вы можете использовать torch.nn.BCELoss:

criterion = nn.BCELoss()

net_out = net(data)
loss = criterion(net_out, target)

Это должно хорошо работать для вас.

Вы также можете использовать torch.nn.BCEWithLogitsLoss, эта функция потери уже включает в себя функцию sigmoid, так что вы можете оставить ее в своем форварде.

Если вы, дляпо какой-то причине, хотите использовать 2 выходных единиц, это также возможно.Но тогда вам нужно использовать torch.nn.CrossEntropyLoss вместо BCELoss.Активация Softmax уже включена в эту функцию потери.

...