Как я узнаю, что какая-то часть моей реализации U-сети неверна? - PullRequest
0 голосов
/ 24 апреля 2019

Прямо сейчас я пытаюсь внедрить U-net для рисования изображений. Я получил свой код для запуска без ошибок, и при распечатке потерь в поездах и потерь при проверке он действительно сходится к некоторой точке. Изображения действительно напоминают правду о земле, но их качество довольно низкое. Когда я смотрю на значения пикселей в неокрашенной части, это далеко не сходится к истинным значениям. Кажется, будто он случайно меняет свои значения.

С другой стороны, я смог выполнить ту же задачу с Matconvnet, библиотекой MATLAB, так что, думаю, она должна работать и в pytorch.

Так что проблема в том, что я думаю, что у меня есть проблемы с реализацией U-net с pytorch. Тем не менее, Я не знаю, где я ошибаюсь в реализации, так как я не получаю никаких ошибок. Ниже моя реализация.

Я попытался создать заново, но это не сработало. Сначала я подумал, что это может быть проблема с гиперпараметрами или методами инициализации, поэтому я попытался также переключить их, но ни один из них не помог.

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.bn = nn.BatchNorm2d(out_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.bn2 = nn.BatchNorm2d(out_size)
        self.activation = activation

    def forward(self, x):
        x1 = self.activation(self.bn(self.conv(x)))
        out = self.activation(self.bn2(self.conv2(x1)))

        return out

class UNetLastBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetLastBlock, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.bn = nn.BatchNorm2d(out_size)
        self.conv2 = nn.Conv2d(out_size,in_size, kernel_size, padding=1)
        self.bn2 = nn.BatchNorm2d(in_size)
        self.activation = activation

    def forward(self, x):
        x1 = self.activation(self.bn(self.conv(x)))
        out = self.activation(self.bn2(self.conv2(x1)))

        return out

class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, in_size, 2, stride=2)
        # Due to concat
        self.conv = nn.Conv2d(in_size * 2, in_size, kernel_size, padding=1)
        self.bn = nn.BatchNorm2d(in_size)
        self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.bn2 = nn.BatchNorm2d(out_size)
        self.activation = activation

    def forward(self, x, bridge):
        up = self.up(x)
        out = torch.cat([up, bridge], dim=1)
        out = self.activation(self.bn(self.conv(out)))
        out = self.activation(self.bn2(self.conv2(out)))

        return out


class UNet(nn.Module):
    def __init__(self, in_c, out_c):
        super(UNet, self).__init__()
        self.activation = F.relu

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)

        self.conv_block2_64 = UNetConvBlock(in_c, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetLastBlock(512, 1024)

        self.up_block1024_512 = UNetUpBlock(512, 256)
        self.up_block512_256 = UNetUpBlock(256, 128)
        self.up_block256_128 = UNetUpBlock(128, 64)
        self.up_block128_64 = UNetUpBlock(64, 64)

        self.last = nn.Conv2d(64, out_c, 1)


    def forward(self, x):
        block1 = self.conv_block2_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)

        block5 = self.conv_block512_1024(pool4)

        up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(up1, block3)

        up3 = self.up_block256_128(up2, block2)

        up4 = self.up_block128_64(up3, block1)

        return self.last(up4)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...