PyTorch - BCELoss: ValueError: Target и input должны иметь одинаковое количество элементов - PullRequest
1 голос
/ 30 июня 2019

Когда я использую BCELoss в качестве функции потерь моей нейронной сети, получаю ValueError: Target and input must have the same number of elements.

Вот мой код для фазы тестирования (что является довольно типичным кодом фазы тестирования):

network.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)

    output = network(data)
    output = output.to(device)
    test_loss += loss_function(output, target).item() # error happens here
    _, predicted = torch.max(output.data, 1)
    correct += (predicted == target).sum().item()

Форма переменной output равна [1000, 10], поскольку существует 10 целевых классов (в MNIST набор данных), а форма переменной target имеет значение [1000], поскольку содержит целевые классы тестируемой партии (размер партии для теста установлен на 10).Итак, вопрос в том, как я могу применить BCELoss в качестве функции потерь сети CNN?

ps Набор данных, который я использую, - это набор данных MNIST , который предоставляется torchvision library.

ps Ответ на аналогичный вопрос, приведенный здесь , не предлагает решения для моего случая.

1 Ответ

1 голос
/ 30 июня 2019

Ответ , на который вы претендуете, не предлагает решения, фактически решает вашу проблему:

ваши цели не завершены!Если есть несколько классов, вы должны работать с torch.nn.CrossEntropyLoss вместо torch.nn.BCELoss()

Напомним, torch.nn.BCELoss() предназначен для использованиядля задачи классификации c независимых двоичных атрибутов для каждого входного примера.С другой стороны, перед вами стоит задача классифицировать каждый вывод в один из c взаимоисключающих классов.Для этой задачи вам нужна другая потеря, torch.nn.CrossEntropyLoss().
Различные задачи, представленные различными функциями потерь, требуют различного контроля (меток).Если вы хотите классифицировать каждый пример к одному из c взаимоисключающих классов, вам потребуется только одна целочисленная метка для каждого примера (как у вас в примере с mnist).Однако, если вы хотите классифицировать каждый пример на c независимых двоичных атрибутов, вам необходимо для каждого примера c двоичные метки - и именно поэтому Pytorch выдает ошибку.

...