Реализация U-сети с пропуском соединения в Pytorch - PullRequest
0 голосов
/ 23 апреля 2019

Я хочу реализовать Unet с пропуском подключений в pytorch. Моя ванильная сеть выглядит так:

def createVanillaGan(self):

    # Encoding layers
    self.conv_1 = nn.Sequential(
    nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False), # in channel, out channel, filter kernel size
    nn.BatchNorm2d( 64 ),
    nn.LeakyReLU( 0.1 )
    )

    self.conv_2 = nn.Sequential(
    nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), # in channel, out channel, filter kernel size
    nn.BatchNorm2d( 128 ),
    nn.LeakyReLU( 0.1 )
    )
    .
    .
    .
    self.conv_trans_4 = nn.Sequential(
    nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    )

    self.conv_trans_5 = nn.Sequential(
    nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1, bias=False),
    #nn.BatchNorm2d(3),
    nn.Tanh()
    )

Мой прямой проход выглядит так

def forward(self, data):
    output1 = self.conv_1(data) #10x64x128x128
    output2 = self.conv_2(output1) #10x128x64x64
    output3 = self.conv_3(output2) #10x256x32x32
    .
    .
    .
    #decoding
    output4_de = self.conv_trans_4(output3_de) #10x64x128x128

    output5_de = self.conv_trans_5(output4_de) #10x128x64x64

Что я хочу сделать, это объединить выходные данные в прямом проходе. Могу ли я просто выполнить torch.cat ((output5, output5_de), 1) в прямом проходе или мне нужно внести изменения и в createVanillaGan (self)? Мне интересно, как это повлияет на обратное распространение, или я могу просто изменить ход вперед и сделать это?

...