Мульти метка классификации в pytorch - PullRequest
0 голосов
/ 17 октября 2018

У меня проблема классификации по нескольким меткам.У меня 11 классов, около 4к примеров.Каждый пример может иметь от 1 до 4-5 меток.В данный момент я тренирую классификатор отдельно для каждого класса с log_loss.Как вы можете ожидать, обучение 11 классификатора занимает довольно много времени, и я хотел бы попробовать другой подход и обучить только 1 классификатору.Идея состоит в том, что последний слой этого классификатора будет иметь 11 узлов и будет выводить действительное число по классам, которое будет преобразовано в пробу сигмоидом.Потери, которые я хочу оптимизировать, являются средним значением log_loss для всех классов.

К сожалению, я какой-то нуб с pytorch, и даже читая исходный код потерь, я не могу понять, выполняет ли одна из уже существующих потерь именно то, что я хочу, или еслиЯ должен создать новую потерю, и если это так, я не знаю, как это сделать.

Чтобы быть очень конкретным, я хочу дать для каждого элемента партии один вектор размером 11(который содержит действительное число для каждой метки (чем ближе к бесконечности, тем ближе этот класс к 1) и 1 вектору размера 11 (который содержит 1 на каждой истинной метке), и может вычислять среднее значениеlog_loss на всех 11 ярлыках и оптимизировать мой классификатор на основе этой потери.

Любая помощь будет принята с благодарностью :)

1 Ответ

0 голосов
/ 17 октября 2018

Вы ищете torch.nn.BCELoss.Вот пример кода:

import torch

batch_size = 2
num_classes = 11

loss_fn = torch.nn.BCELoss()

outputs_before_sigmoid = torch.randn(batch_size, num_classes)
sigmoid_outputs = torch.sigmoid(outputs_before_sigmoid)
target_classes = torch.randint(0, 2, (batch_size, num_classes))  # randints in [0, 2).

loss = loss_fn(sigmoid_outputs, target_classes)

# alternatively, use BCE with logits, on outputs before sigmoid.
loss_fn_2 = torch.nn.BCEWithLogitsLoss()
loss2 = loss_fn_2(outputs_before_sigmoid, target_classes)
assert loss == loss2
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...