Автоэнкодер для восстановления изображения выдает серое изображение со странной сеткой - PullRequest
1 голос
/ 15 апреля 2020

Я пытаюсь создать Deep Fake, используя авто-кодер. Я использую один кодер и два декодера: один для целевого изображения, а другой для исходного изображения (целевое лицо - это лицо, которое я хочу «вставить» на голову источника). Итак, сначала я пытаюсь обучить кодировщик и оба декодера восстановить входные грани (300, 300, 3). Но для обоих декодеров на выходе получается не цветное изображение, а серое, потому что для каждого пикселя значения Red, Green и Blue почти одинаковы. Кроме того, на выходных изображениях есть странная сетка 3х3: input image and produced image

Я использую размер пакета 1, потому что я не знаю, как это сделать с мини-пакетами в этот случай (но это еще одна проблема). Я также использую остаточные соединения, которые улучшили качество. Последний слой имеет сигмовидную активацию (что может быть неправильно). Моя потеря - Бинарная Кросс-Энтропия, а оптимизатор - Адам. Скорость обучения составляет 0,001 (я также пробовал 0,0001 и 0,00075).

Вот моя модель:

import matplotlib.pyplot as plt
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F




class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        """ encoder """
        self.conv1 = nn.Conv2d(3, 32, kernel_size=(4, 4))
        self.batchnorm1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=(4, 4))
        self.batchnorm2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3))
        self.batchnorm3 = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 256, kernel_size=(4, 4))
        self.batchnorm4 = nn.BatchNorm2d(256)

        self.maxpool3x3 = nn.MaxPool2d(3)
        self.maxpool2x2 = nn.MaxPool2d(2)

        """ target-decoder """
        self.targetDeconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(4, 4))
        self.targetBatchnorm1 = nn.BatchNorm2d(128)

        self.targetDeconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(3, 3))
        self.targetBatchnorm2 = nn.BatchNorm2d(64)

        self.targetDeconv3 = nn.ConvTranspose2d(64, 32, kernel_size=(4, 4))
        self.targetBatchnorm3 = nn.BatchNorm2d(32)

        self.targetDeconv4 = nn.ConvTranspose2d(32, 3, kernel_size=(4, 4))

        self.upsample3x3 = nn.Upsample(scale_factor=3)
        self.upsample2x2 = nn.Upsample(scale_factor=2)

        """ source-decoder """
        self.sourceDeconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(4, 4))
        self.sourceBatchnorm1 = nn.BatchNorm2d(128)

        self.sourceDeconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(3, 3))
        self.sourceBatchnorm2 = nn.BatchNorm2d(64)

        self.sourceDeconv3 = nn.ConvTranspose2d(64, 32, kernel_size=(4, 4))
        self.sourceBatchnorm3 = nn.BatchNorm2d(32)

        self.sourceDeconv4 = nn.ConvTranspose2d(32, 3, kernel_size=(4, 4))

        self.upsample3x3 = nn.Upsample(scale_factor=3)
        self.upsample2x2 = nn.Upsample(scale_factor=2)

    def _visualize_features(self, feature_maps, dim: tuple=(), title: str=""):
        try:
            x, y = dim
            fig, axs = plt.subplots(x, y)
            c = 0
            for i in range(x):
                for j in range(y):
                    axs[i][j].matshow(feature_maps.detach().cpu().numpy()[0][c])
                    c += 1

            fig.suptitle(title)
            plt.show()

        except:
            pass

    def forward(self, x, label: str="0", visualize: bool=False):
        """ encoder """
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = F.relu(x)
        x_1 = self.maxpool3x3(x)

        if visualize: print(x_1.shape); self._visualize_features(x_1, dim=(4, 4))

        x = self.conv2(x_1)
        x = self.batchnorm2(x)
        x = F.relu(x)
        x_2 = self.maxpool3x3(x)

        if visualize: print(x_2.shape); self._visualize_features(x_2, dim=(4, 4))

        x = self.conv3(x_2)
        x = self.batchnorm3(x)
        x = F.relu(x)
        x_3 = self.maxpool2x2(x)

        if visualize: print(x_3.shape); self._visualize_features(x_3, dim=(4, 4))

        x = self.conv4(x_3)
        x = self.batchnorm4(x)
        x = F.relu(x)
        x = self.maxpool2x2(x)

        if visualize: print(x.shape); self._visualize_features(x, dim=(4, 4))


        """ target-decoder """
        if label == "0":
            x = self.upsample2x2(x)
            x = self.targetDeconv1(x)
            x += x_3
            x = self.targetBatchnorm1(x)
            x = F.relu(x)

            if visualize: print(x.shape); self._visualize_features(x, dim=(4, 4))

            x = self.upsample2x2(x)
            x = self.targetDeconv2(x)
            x += x_2
            x = self.targetBatchnorm2(x)
            x = F.relu(x)

            if visualize: print(x.shape); self._visualize_features(x, dim=(4, 4))

            x = self.upsample3x3(x)
            x = self.targetDeconv3(x)
            x += x_1
            x = self.targetBatchnorm3(x)
            x = F.relu(x)

            if visualize: print(x.shape); self._visualize_features(x, dim=(4, 4))

            x = self.upsample3x3(x)
            x = self.targetDeconv4(x)
            x = torch.sigmoid(x)

            if visualize: print(x.shape); self._visualize_features(x, dim=(3, 1))

            return x

        """ source-decoder """
        if label == "1":
            x = self.upsample2x2(x)
            x = self.sourceDeconv1(x)
            x += x_3
            x = self.sourceBatchnorm1(x)
            x = F.relu(x)

            if visualize: print(x.shape); self._visualize_features(x, dim=(4, 4))

            x = self.upsample2x2(x)
            x = self.sourceDeconv2(x)
            x += x_2
            x = self.sourceBatchnorm2(x)
            x = F.relu(x)

            if visualize: print(x.shape); self._visualize_features(x, dim=(4, 4))

            x = self.upsample3x3(x)
            x = self.sourceDeconv3(x)
            x += x_1
            x = self.sourceBatchnorm3(x)
            x = F.relu(x)

            if visualize: print(x.shape); self._visualize_features(x, dim=(4, 4))

            x = self.upsample3x3(x)
            x = self.sourceDeconv4(x)
            x = torch.sigmoid(x)

            if visualize: print(x.shape); self._visualize_features(x, dim=(3, 1))

            return x

Итак, мои проблемы: Почему выходное изображение серое? Что это за сетка на моем выходном изображении? Обе проблемы связаны? Я проверил, имеет ли это какое-либо отношение к rgb против bgr, но это не похоже на это. Я надеюсь, что кто-нибудь может решить мою проблему, спасибо заранее:)

...