Как обучить оригинальную модель U-Net с PyTorch? - PullRequest
0 голосов
/ 26 июня 2019

Я пытаюсь реализовать и обучить оригинальную модель U-Net , но я застрял, когда пытаюсь обучить модель с использованием ISBI Challenge Dataset .

Согласно исходной модели U-Net, сеть выводит изображение с 2 каналами и размером 388 x 388. Итак, мой загрузчик данных для обучения генерирует тензор размером [batch, channel = 1, ширина = 572, высота = 572] для входных изображений и [пакет, каналы = 2, ширина = 388, ширина = 388] для целевых / выходных изображений.

Моя проблема на самом деле заключается в том, что когда я пытаюсь использовать nn.CrossEntropyLoss (), возникает следующая ошибка:

RuntimeError: недопустимый аргумент 3: поддерживаются только пакеты пространственных целей (трехмерные тензоры), но получены цели измерения: 4 в / opt / conda / conda-bld / pytorch_1556653099582 / work / aten / src / THNN / generic / SpatialClassNLLCriterion. C: 59

Я только начинаю с PyTorch (новичок здесь) ... так что я буду очень признателен, если кто-нибудь поможет мне преодолеть эту проблему.

Исходный код доступен на GitHub:

https://github.com/dalifreire/cnn_unet_pytorch https://github.com/dalifreire/cnn_unet_pytorch/blob/master/unet_pytorch.ipynb

С уважением!

...