Deconvolution (ConvTranspose2d) initialization with bilinear interpolation in PyTorch?
1 vote
/ 17 February 2020

I am working with a PyTorch implementation of this paper (https://arxiv.org/pdf/1604.03650.pdf).

The paper describes initializing the deconv layers with bilinear interpolation, which is missing from this code. It says that if a deconvolution upscales its input by a factor S, the initial weights are defined as:

    w_ij = (1 - |i/S - C|) * (1 - |j/S - C|),   where C = (2S - 1 - (S mod 2)) / (2S)

Does anyone know how I can implement this?
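A direct transcription of that formula could look like the sketch below (my own attempt, not code from the paper; it assumes the square 2S kernel / stride S / padding S//2 layout that the model uses, and that bilinear upsampling acts channel-wise, so only matching in/out channel pairs receive the filter):

    import torch

    def bilinear_kernel(in_channels, out_channels, scale):
        # kernel size 2S, matching ConvTranspose2d(kernel_size=2*scale,
        # stride=scale, padding=scale//2) as used in the model below
        kernel_size = 2 * scale
        c = (2 * scale - 1 - scale % 2) / (2.0 * scale)   # C from the formula
        idx = torch.arange(kernel_size, dtype=torch.float32)
        f = 1 - torch.abs(idx / scale - c)                # 1 - |i/S - C|
        filt = f[:, None] * f[None, :]                    # w_ij = f_i * f_j
        # ConvTranspose2d stores weights as (in_channels, out_channels, kH, kW)
        weight = torch.zeros(in_channels, out_channels, kernel_size, kernel_size)
        for ch in range(min(in_channels, out_channels)):
            weight[ch, ch] = filt                         # per-channel filter
        return weight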

This is the neural-network model developed from the paper:

In other words, I don't know how to initialize the deconvolution layers (for example, the deconv_1 layer in this code).

    import torch
    import torch.nn as nn
    import numpy as np 
    import torchvision
    import torch.nn.functional as F

    cfg = [64, 128, 256, 512, 512]

    class Deep3d(nn.Module):
        def __init__(self, in_channels=3, out_channels=3, device=torch.device('cpu')):
            super(Deep3d, self).__init__()
            self.device = device
            vgg16 = torchvision.models.vgg16_bn(pretrained=True)
            modules = []
            layer = []
            for l in vgg16.features:
                if isinstance(l, nn.MaxPool2d):
                    layer.append(l)
                    modules.append(layer)
                    layer = []
                else:
                    layer.append(l)

            scale = 1
            deconv = []
            layer = []
            for m in range(len(modules)):
                layer.append(nn.Conv2d(cfg[m], cfg[m], kernel_size=3, stride=1, padding=1))
                layer.append(nn.ReLU(inplace=True))

                layer.append(nn.Conv2d(cfg[m], cfg[m], kernel_size=3, stride=1, padding=1))
                layer.append(nn.ReLU(inplace=True))

                if(m==0):
                    layer.append(nn.ConvTranspose2d(cfg[m], 65, kernel_size=1, stride=1, padding=(0,0)))

                else:
                    scale *=2
                    layer.append(nn.ConvTranspose2d(cfg[m], 65, kernel_size=scale*2, stride=scale, padding=(scale//2, scale//2)))


                deconv.append(layer)           # add blocks of layers to deconv part of the network
                layer = []

            self.module_1 = nn.Sequential(*modules[0])
            self.module_2 = nn.Sequential(*modules[1])
            self.module_3 = nn.Sequential(*modules[2])
            self.module_4 = nn.Sequential(*modules[3])
            self.module_5 = nn.Sequential(*modules[4])

            self.deconv_1 = nn.Sequential(*deconv[0])
            self.deconv_2 = nn.Sequential(*deconv[1])
            self.deconv_3 = nn.Sequential(*deconv[2])
            self.deconv_4 = nn.Sequential(*deconv[3])
            self.deconv_5 = nn.Sequential(*deconv[4])

            self.linear_module = nn.Sequential(*[nn.Linear(15360,4096),          # hyperparam choice
                                                nn.ReLU(inplace=True),
                                                nn.Dropout(p=0.5),
                                                nn.Linear(4096,1950)])          # 1950=65(disparity range)*10*3(10*3 is feature map size)


            self.deconv_6 = nn.Sequential(*[nn.ConvTranspose2d(65,65,kernel_size=scale*2,stride=scale,padding=(scale//2,scale//2))])

            self.upconv_final = nn.Sequential(*[nn.ConvTranspose2d(65,65,kernel_size=(4,4),stride=2,padding=(1,1)),
                                                nn.ReLU(inplace=True),
                                                nn.Conv2d(65,65,kernel_size=(3,3),stride=1,padding=(1,1)),
                                                nn.Softmax(dim=1)])

            for block in [self.deconv_1,self.deconv_2,self.deconv_3,self.deconv_4,self.deconv_5,self.deconv_6,self.linear_module,self.upconv_final]:
                for m in block:
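                    # note: none of the branches below match nn.ConvTranspose2d
                    # (it is not a subclass of nn.Conv2d), so the deconvolution
                    # layers currently keep PyTorch's default initialization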
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                        if m.bias is not None:
                            nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)
                    elif isinstance(m, nn.Linear):
                        nn.init.normal_(m.weight, 0, 0.01)
                        nn.init.constant_(m.bias, 0)


        def forward(self, orig_x, x):

            x_copy = orig_x
            pred = []
            out_1 = self.module_1(x)
            out_2 = self.module_2(out_1)
            out_3 = self.module_3(out_2)
            out_4 = self.module_4(out_3)
            out_5 = self.module_5(out_4)
            # print(out_5.shape)

            out_5_flatten = out_5.view(x_copy.shape[0],-1)
            out_6 = self.linear_module(out_5_flatten)

            p1 = self.deconv_1(out_1)

            p2 = self.deconv_2(out_2)
            p3 = self.deconv_3(out_3)
            p4 = self.deconv_4(out_4)
            p5 = self.deconv_5(out_5)
            # print(p5.shape)
            p6 = self.deconv_6(out_6.view(x_copy.shape[0],65,3,10))

            pred.append(p1)
            pred.append(p2)
            pred.append(p3)
            pred.append(p4)
            pred.append(p5)
            pred.append(p6)

            out = torch.zeros(pred[0].shape).to(self.device)
            for p in pred:
                out = torch.add(out, p)

            out = self.upconv_final(out)               # to be elt wise multiplied with shifted left views


            out = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=False) # upscale to match left input size


            new_right_image = torch.zeros(x_copy.size()).to(self.device)
            stacked_shifted_view = None
            stacked_out = None
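            # build 65 horizontally shifted copies of the left view, one per
            # disparity in [-33, 31], and pair each with its probability map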
            for depth_map_idx in range(-33,32):
                shifted_input_view = torch.zeros(x_copy.size()).to(self.device)

                if depth_map_idx<0:
                    shifted_input_view[:,:,:,:depth_map_idx] = x_copy[:,:,:,-depth_map_idx:]
                elif depth_map_idx==0:
                    shifted_input_view = x_copy
                else:
                    shifted_input_view[:,:,:,depth_map_idx:] = x_copy[:,:,:,:-depth_map_idx]


                if stacked_shifted_view is None:
                    stacked_shifted_view = shifted_input_view.unsqueeze(1)
                else:
                    stacked_shifted_view = torch.cat((stacked_shifted_view,shifted_input_view.unsqueeze(1)),dim=1)

                if stacked_out is None:
                    stacked_out = out[:,depth_map_idx+33:depth_map_idx+34,:,:].unsqueeze(1)
                else:
                    stacked_out = torch.cat((stacked_out,out[:,depth_map_idx+33:depth_map_idx+34,:,:].unsqueeze(1)),dim=1)

            softmaxed_stacked_shifted_view = stacked_shifted_view

            mult_soft_shift_out = torch.mul(stacked_out,softmaxed_stacked_shifted_view)

            final_rt_image = torch.sum(mult_soft_shift_out,dim=1)

            return final_rt_image





    # if(__name__=='__main__'):
    #     vgg16 = torchvision.models.vgg16(pretrained=True)
    #     model = Deep3d().to(torch.device('cpu'))
    #     out = model(torch.randn(10,3,384,1280),torch.randn(10,3,96,320))
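If that helper is right, a hypothetical way to apply it would be to walk the modules after constructing the model (sketch only; it relies on every upsampling deconv above using stride == S and kernel_size == 2S, and where in_channels != out_channels the formula only covers the matching channel pairs, the rest stay zero):

    import torch
    import torch.nn as nn

    def init_deconv_bilinear(model):
        # hypothetical helper: reset every upsampling ConvTranspose2d to the
        # paper's bilinear-interpolation kernel; uses bilinear_kernel from above
        for m in model.modules():
            if isinstance(m, nn.ConvTranspose2d):
                scale = m.stride[0]        # stride equals the upscale factor S
                if scale > 1:              # skip the 1x1, stride-1 deconv
                    with torch.no_grad():
                        m.weight.copy_(bilinear_kernel(m.in_channels,
                                                       m.out_channels, scale))
                        if m.bias is not None:
                            m.bias.zero_()

    # usage:
    # model = Deep3d()
    # init_deconv_bilinear(model)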