Я пытаюсь построить простой пример с нулевым CrossEntropyLoss, с тремя классами и размером мини-пакета = 1:
ce_loss = nn.CrossEntropyLoss()
outputs = torch.tensor([[1,0,3]], dtype=float)
targets = torch.tensor([1,0,3], dtype=torch.int64)
loss = ce_loss(outputs, targets)
Что приводит к ошибке:
ValueError: Expected input batch_size (1) to match target batch_size (3).
Что не так и откуда берется «target batch_size (3)»?
Обновление 1
В соответствии с моим пониманием документов PyTorch я пробую следующие случаи. Пожалуйста, укажите, где я не прав.
Пример 1:
outputs = torch.tensor([[1,1]], dtype=float)
targets = torch.tensor([1], dtype=torch.int64)
loss = ce_loss(outputs, targets)
Вывод:
loss: 0.6931471805599453
Вопросы:
1) У меня есть один класс равно 1 и два прогноза для этого класса каждый равен 1, правильно?
2) Почему ненулевые потери?
Пример 2:
outputs = torch.tensor([[1]], dtype=float)
targets = torch.tensor([1], dtype=torch.int64)
loss = ce_loss(outputs, targets)
Ошибка:
IndexError: Target 1 is out of bounds.
Вопросы:
1) У меня есть один класс, равный 1, и одно предсказание для этого класса, равное 1, верно?
2) Почему ошибка индекса?
Пример 3:
outputs = torch.tensor([[1,1],[2,2]], dtype=float)
targets = torch.tensor([1, 2], dtype=torch.int64)
loss = ce_loss(outputs, targets)
Ошибка:
IndexError: Target 2 is out of bounds.
Вопросы:
1 ) У меня есть два класса, равные 1 и 2, два прогноза для этих классов, равные 1 и 2 (размер партии = 2), верно?
2) Почему ошибка индекса?
--- Дополнительные вопросы: ----
Какую форму должны иметь выходные данные и цели для следующих случаи:
1) Бинарная классификация, один временной шаг, одно значение, предсказанное одним выходным нейроном, размер пакета = 1
2) Бинарная классификация, пять различных временных шагов, 5 значений, предсказанных одним выходной нейрон, размер пакета = 5
3) семь классов, один временной шаг, 7 значений, предсказанных 7 выходными нейронами, размер пакета = 1
4) семь классов, пять различных временных шагов, 5 наборов из 7 значений предсказаны 7 выходными нейронами, размер партии = 5