Автоэнкодер Pytorch - Как улучшить потери? - PullRequest
0 голосов
/ 19 марта 2020

Ниже представлен автоэкодер в стиле UNET с фильтром, который я написал в Pytorch в конце. Сеть сходится быстрее, чем должна, и я не знаю почему. У меня есть набор данных из 4000 изображений, и я каждый раз собираю кадры 128x128. Я использую график тренировок и снижение веса. Я попытался поиграться с моими параметрами с помощью небольшого набора данных, чтобы увидеть улучшения, но, похоже, ничего не работает. Как только скорость обучения снижается, потеря просто отскакивает и не падает на пол, а в некоторых случаях снова возрастает. Моя сеть выглядит следующим образом:

import torch
import torch.nn as nn
from wiener_3d import wiener_3d
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random

def np_to_pil(np_imgs):

img_num = np_imgs.shape[0]
channel_num = np_imgs.shape[1]
ar = np.clip(np_imgs*255, 0, 255).astype(np.uint8)

pil_imgs = []
for i in range(img_num):
    if channel_num == 1:
        img = ar[i][0]
    else:
        img = ar[i].transpose(1, 2, 0)
    pil_imgs.append(Image.fromarray(img))

return pil_imgs


class WienerFilter(nn.Module):
def __init__(self, param_b=16):
    super(WienerFilter, self).__init__()
    # self.register_parameter("param_a", nn.Parameter(torch.tensor(param_a)))
    # self.param_a = nn.Parameter(torch.tensor(param_a))
    # self.param_a.requires_grad = True
    self.param_b = param_b

def forward(self, input, std):
    tensors = input.shape[0]
    for i in range(tensors):
        tensor = input[i]
        tensor = torch.squeeze(tensor)
        # tensor = wiener_3d(tensor, self.param_a, self.param_b
        tensor = wiener_3d(tensor, 2*std, self.param_b)
        tensor = torch.unsqueeze(tensor, 0)
        input[i] = tensor
    return input


class AutoEncoder(nn.Module):
"""Autoencoder simple implementation """
def __init__(self):
    super(AutoEncoder, self).__init__()
    # Encoder
    # conv layer
    self.block1 = nn.Sequential(
        nn.Conv2d(1, 96, 3, padding=1),
        nn.BatchNorm2d(96),
        nn.LeakyReLU(0.1),
        nn.Conv2d(96, 96, 3, padding=1),
        nn.MaxPool2d(2),
        nn.BatchNorm2d(96),
        nn.LeakyReLU(0.1)

    )
    self.block2 = nn.Sequential(
        nn.Conv2d(96, 96, 3, padding=1),
        nn.MaxPool2d(2),
        nn.BatchNorm2d(96),
        nn.LeakyReLU(0.1)
    )
    self.block3 = nn.Sequential(
        nn.Conv2d(96, 96, 3, padding=1),
        nn.BatchNorm2d(96),
        nn.LeakyReLU(0.1),
        nn.ConvTranspose2d(96, 96, 2, 2),
        nn.BatchNorm2d(96),
        nn.LeakyReLU(0.1)
    )
    self.block4 = nn.Sequential(
        nn.Conv2d(192, 192, 3, padding=1),
        nn.BatchNorm2d(192),
        nn.LeakyReLU(0.1),
        nn.Conv2d(192, 192, 3, padding=1),
        nn.BatchNorm2d(192),
        nn.LeakyReLU(0.1),
        nn.ConvTranspose2d(192, 192, 2, 2),
        nn.BatchNorm2d(192),
        nn.LeakyReLU(0.1)
    )
    self.block5 = nn.Sequential(
        nn.Conv2d(288, 192, 3, padding=1),
        nn.BatchNorm2d(192),
        nn.LeakyReLU(0.1),
        nn.Conv2d(192, 192, 3, padding=1),
        nn.BatchNorm2d(192),
        nn.LeakyReLU(0.1),
        nn.ConvTranspose2d(192, 192, 2, 2),
        nn.BatchNorm2d(192),
        nn.LeakyReLU(0.1)
    )
    self.block6 = nn.Sequential(
        nn.Conv2d(193, 96, 3, padding=1),
        nn.BatchNorm2d(96),
        nn.LeakyReLU(0.1),
        nn.Conv2d(96, 64, 3, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.1),
        nn.Conv2d(64, 32, 3, padding=1),
        nn.LeakyReLU(0.1),
        nn.Conv2d(32, 1, 3, padding=1),
        nn.LeakyReLU(0.1)
    )

    self.wiener_filter = WienerFilter()

def forward(self, x, std):
    # torch.autograd.set_detect_anomaly(True)
    # print("input: ", x.shape)
    pool1 = self.block1(x)
    # print("pool1: ", pool1.shape)
    pool2 = self.block2(pool1)
    # print("pool2: ", pool2.shape)
    pool3 = self.block2(pool2)
    # print("pool3: ", pool3.shape)
    pool4 = self.block2(pool3)
    # print("pool4: ", pool4.shape)
    pool5 = self.block2(pool4)
    # print("pool5: ", pool5.shape)
    upsample5 = self.block3(pool5)
    # print("upsample5: ", upsample5.shape)
    concat5 = torch.cat((upsample5, pool4), 1)
    # print("concat5: ", concat5.shape)
    upsample4 = self.block4(concat5)
    # print("upsample4: ", upsample4.shape)
    concat4 = torch.cat((upsample4, pool3), 1)
    # print("concat4: ", concat4.shape)
    upsample3 = self.block5(concat4)
    # print("upsample3: ", upsample3.shape)
    concat3 = torch.cat((upsample3, pool2), 1)
    # print("concat3: ", concat3.shape)
    upsample2 = self.block5(concat3)
    # print("upsample2: ", upsample2.shape)
    concat2 = torch.cat((upsample2, pool1), 1)
    # print("concat2: ", concat2.shape)
    upsample1 = self.block5(concat2)
    # print("upsample1: ", upsample1.shape)
    concat1 = torch.cat((upsample1, x), 1)
    # print("concat1: ", concat1.shape)
    output = self.block6(concat1)
    path = "test"
    t_map = x - output

    filtering = self.wiener_filter(t_map, std) 

    filtered_output = output + filtering

    return filtered_output

Мои текущие параметры: оптимизатор Адама, спад скорости обучения на 0,1, если нет улучшения в течение 7 эпох, начальная скорость обучения 0,001, спад веса 0,0001, нет пакетов.

Я чувствую, что на этом этапе все перепробовал. Может кто-нибудь дать мне несколько советов о том, как улучшить мою сеть? Спасибо.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...