Я хочу реализовать Unet с пропуском подключений в pytorch. Моя ванильная сеть выглядит так:
def createVanillaGan(self):
# Encoding layers
self.conv_1 = nn.Sequential(
nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False), # in channel, out channel, filter kernel size
nn.BatchNorm2d( 64 ),
nn.LeakyReLU( 0.1 )
)
self.conv_2 = nn.Sequential(
nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), # in channel, out channel, filter kernel size
nn.BatchNorm2d( 128 ),
nn.LeakyReLU( 0.1 )
)
.
.
.
self.conv_trans_4 = nn.Sequential(
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(),
)
self.conv_trans_5 = nn.Sequential(
nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1, bias=False),
#nn.BatchNorm2d(3),
nn.Tanh()
)
Мой прямой проход выглядит так
def forward(self, data):
output1 = self.conv_1(data) #10x64x128x128
output2 = self.conv_2(output1) #10x128x64x64
output3 = self.conv_3(output2) #10x256x32x32
.
.
.
#decoding
output4_de = self.conv_trans_4(output3_de) #10x64x128x128
output5_de = self.conv_trans_5(output4_de) #10x128x64x64
Что я хочу сделать, это объединить выходные данные в прямом проходе. Могу ли я просто выполнить torch.cat ((output5, output5_de), 1) в прямом проходе или мне нужно внести изменения и в createVanillaGan (self)? Мне интересно, как это повлияет на обратное распространение, или я могу просто изменить ход вперед и сделать это?