Использование функции активации Softmax после расчета убытка от BCEWithLogitLoss (двоичная кросс-энтропия + сигмоидальная активация) - PullRequest
1 голос
/ 27 мая 2020

Я просматриваю учебник по двоичной классификации, используя PyTorch, и здесь последний уровень сети - torch.Linear() с одним нейроном. (Имеет смысл), который даст нам единственный нейрон. как pred=network(input_batch)

После этого выбор функции потерь - loss_fn=BCEWithLogitsLoss() (что численно стабильно, чем использование сначала softmax, а затем вычисление потерь), которая применит функцию Softmax к выходу последнего слоя для дайте нам вероятность. поэтому после этого он вычислит двоичную кросс-энтропию, чтобы минимизировать потери.

loss=loss_fn(pred,true)

Меня беспокоит, что после всего этого автор использовал torch.round(torch.sigmoid(pred))

Почему это должно быть? Я имею в виду, что я знаю, что он получит вероятности предсказания в диапазоне [0,1], а затем округлит значения с порогом по умолчанию 0,5.

Не лучше ли использовать sigmoid один раз после последнего слой внутри сети, а не используя softmax и сигмоид в 2 разных местах, учитывая бинарную классификацию ??

Не лучше было бы просто

out = self.linear(batch_tensor)
return self.sigmoid(out)

, а затем вычислить BCE потерю и использовать argmax() для проверки точности ??

Я просто любопытно, может ли это быть действенной стратегией?

1 Ответ

2 голосов
/ 27 мая 2020

Кажется, вы думаете о бинарной классификации как о мультиклассовой классификации с двумя классами, но это не совсем правильно при использовании бинарного кросс-энтропийного подхода. Давайте начнем с пояснения цели двоичной классификации, прежде чем рассматривать какие-либо детали реализации.

Технически существует два класса, 0 и 1, но вместо того, чтобы рассматривать их как два отдельных класса, вы можете рассматривать их как противоположности друг друга. Например, вы хотите определить, был ли ответ StackOverflow полезным или нет. Два класса будут: «полезный» и «бесполезный» . Естественно, вы бы просто спросили «Был ли ответ полезен?» , отрицательный аспект опущен, и если бы это было не так, вы могли бы сделать вывод, что это было «бесполезно» . (Помните, что это бинарный случай, золотой середины нет).

Следовательно, ваша модель должна предсказывать только один класс, но во избежание путаницы с реальными двумя классами, которые можно выразить как: модель предсказывает вероятность возникновения положительного случая. В контексте предыдущего примера: Какова вероятность того, что ответ StackOverflow был полезен?

Sigmoid дает вам значения в диапазоне [0, 1] , что вероятности. Теперь вам нужно решить, когда модель достаточно уверена, чтобы быть положительной, определив порог. Чтобы сделать его сбалансированным, порог равен 0,5, поэтому, пока вероятность больше 0,5, он положительный (класс 1: «полезно »), в противном случае он отрицательный (класс 0: «бесполезен. "), что достигается округлением (например, torch.round(torch.sigmoid(pred))).

После этого выбор функции потерь составляет loss_fn=BCEWithLogitsLoss() (что численно стабильно, чем использование сначала softmax, а затем вычисление потерь), который применит функцию Softmax к выходным данным последнего слоя, чтобы дать нам вероятность.

Разве не лучше использовать сигмоид один раз после последнего слоя в сети, а не использовать softmax и сигмоидальный элемент в 2 разных местах, учитывая его двоичную классификацию ??

BCEWithLogitsLoss применяется сигмоид не Softmax, Softmax вообще не задействован. Из документации nn.BCEWithLogitsLoss :

Эта потеря объединяет слой сигмоид и BCELoss в одном классе. Эта версия более стабильна в числовом отношении, чем использование простого сигмоида , за которым следует BCELoss , поскольку, объединяя операции в один уровень, мы используем трюк логарифма суммы-экспонирования для численного стабильность.

Не применяя сигмоид в модели, вы получаете более численно стабильную версию двоичной кросс-энтропии, но это означает, что вам придется применять сигмоид вручную, если вы хотите сделать реальный прогноз вне обучения.

[...] и использовать argmax() для проверки точности ??

Опять же, вы думаете о сценарии с несколькими классами . У вас есть только один выходной класс, т.е. размер вывода [batch_size, 1] . Принятие argmax этого всегда даст вам 0, потому что это единственный доступный класс.

...