Почему мой полностью сверточный автоэнкодер не симметричен? - PullRequest
2 голосов
/ 02 октября 2019

Я разрабатываю полностью сверточный автоэнкодер, который принимает 3 канала на вход и выводит 2 канала (in: LAB, out: AB). Поскольку выходной размер должен совпадать с входным, я использую Full Convolution.

Код:

import torch.nn as nn


class AE(nn.Module):
   def __init__(self):
       super(AE, self).__init__()

        self.encoder = nn.Sequential(
           # conv 1
           nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(64),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2, stride=2),

           # conv 2
           nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(128),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2, stride=2),

           # conv 3
           nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(256),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2, stride=2),

           # conv 4
           nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(512),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2, stride=2),

           # conv 5
           nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(1024),
           nn.ReLU()

       )

       self.decoder = nn.Sequential(
           # conv 6
           nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(512),
           nn.ReLU(),

           # conv 7
           nn.Upsample(scale_factor=2, mode='bilinear'),
           nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(256),
           nn.ReLU(),

           # conv 8
           nn.Upsample(scale_factor=2, mode='bilinear'),
           nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(128),
           nn.ReLU(),

           # conv 9
           nn.Upsample(scale_factor=2, mode='bilinear'),
           nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=1, padding=1),
           nn.BatchNorm2d(64),
           nn.ReLU(),

           # conv 10 out
           nn.Upsample(scale_factor=2, mode='bilinear'),
           nn.ConvTranspose2d(in_channels=64, out_channels=2, kernel_size=5, stride=1, padding=1),
           nn.Softmax()    # multi-class classification

           # TODO softmax deprecated
       )

   def forward(self, x):
       x = self.encoder(x)
       x = self.decoder(x)
       return x

Размер тензора вывода должен быть: torch.Size ([1, 2, 199, 253])

Размер, который имеет тензор вывода действительно : torch.Size ([1, 2, 190, 238])

Моя основная проблема - объединить Conv2d и MaxPool2d и установить правильные значения параметров в ConvTranspose2d. Из-за этого я рассматриваю их отдельно, используя функцию Upsample для MaxPool2d и ConvTranspose2d только для Conv2d. Но у меня все еще есть небольшая асимметрия, и я действительно не знаю почему.

Спасибо за помощь!

1 Ответ

1 голос
/ 02 октября 2019

Есть две проблемы.

Первый - недостаточное заполнение: при kernel_size=5 ваши свертки уменьшают изображение на 4 каждый раз, когда они применяются (по 2 пикселя с каждой стороны), поэтому вам нужно padding=2и не только 1, во всех местах.

Второй - это "неравномерный" размер ввода. Я имею в виду, что, как только ваши свертки будут правильно заполнены, у вас останутся операции понижающей дискретизации, которые в каждой точке пытаются разделить разрешение вашего изображения пополам. Когда они терпят неудачу, они просто возвращают меньший результат (целочисленное деление отбрасывает остаток). Поскольку в вашей сети 4 последовательных 2-кратных операции понижающей дискретизации, ваш вход должен иметь размеры H, W, кратные 2^4=16. Тогда вы действительно получите одинаковую форму. Пример ниже

import torch
import torch.nn as nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()

        self.encoder = nn.Sequential(
            # conv 1
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 4
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 5
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )

        self.decoder = nn.Sequential(
            # conv 6
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            # conv 7
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            # conv 8
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            # conv 9
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            # conv 10 out
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=64, out_channels=2, kernel_size=5, stride=1, padding=2),
            nn.Softmax()    # multi-class classification
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

input = torch.randn(1, 3, 6*16, 7*16)
output = AE()(input)
print(input.shape)
print(output.shape)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...