Я пытаюсь реализовать и обучить оригинальную модель U-Net , но я застрял, когда пытаюсь обучить модель с использованием ISBI Challenge Dataset .
Согласно исходной модели U-Net, сеть выводит изображение с 2 каналами и размером 388 x 388. Итак, мой загрузчик данных для обучения генерирует тензор размером [batch, channel = 1, ширина = 572, высота = 572] для входных изображений и [пакет, каналы = 2, ширина = 388, ширина = 388] для целевых / выходных изображений.
Моя проблема на самом деле заключается в том, что когда я пытаюсь использовать nn.CrossEntropyLoss (), возникает следующая ошибка:
RuntimeError: недопустимый аргумент 3: поддерживаются только пакеты пространственных целей (трехмерные тензоры), но получены цели измерения: 4 в / opt / conda / conda-bld / pytorch_1556653099582 / work / aten / src / THNN / generic / SpatialClassNLLCriterion. C: 59
Я только начинаю с PyTorch (новичок здесь) ... так что я буду очень признателен, если кто-нибудь поможет мне преодолеть эту проблему.
Исходный код доступен на GitHub:
https://github.com/dalifreire/cnn_unet_pytorch
https://github.com/dalifreire/cnn_unet_pytorch/blob/master/unet_pytorch.ipynb
С уважением!