Почему модель pytorch работает плохо после установки eval ()? - PullRequest
0 голосов
/ 07 ноября 2019

Я использовал pytorch для построения модели сегментации, которая использует слой BatchNormalization. Я обнаружил, что когда я установил model.eval() в тесте, результат теста будет равен 0. Если я не установлю model.eval(), он будет работать хорошо.

Я пытался найти похожие вопросы, ноЯ пришел к выводу, что model.eval() может исправить параметры BN, но я все еще не понимаю, как решить эту проблему.

Мой размер пакета равен 1, и это моя модель:

import torch
import torch.nn as nn


class Encode_Block(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Encode_Block, self).__init__()

        self.conv1 = Res_Block(in_feat, out_feat)
        self.conv2 = Res_Block_identity(out_feat, out_feat)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class Decode_Block(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Decode_Block, self).__init__()

        self.conv1 = Res_Block(in_feat, out_feat)
        self.conv2 = Res_Block_identity(out_feat, out_feat)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class Conv_Block(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Conv_Block, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        return outputs


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Res_Block(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(Res_Block, self).__init__()
        self.conv_input = conv1x1(inplanes, planes)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.conv3 = conv1x1(planes, planes)
        self.stride = stride

    def forward(self, x):
        residual = self.conv_input(x)

        out = self.conv1(x)
        out = self.bn(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn(out)

        out += residual
        out = self.relu(out)

        return out


class Res_Block_identity(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(Res_Block_identity, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.conv3 = conv1x1(planes, planes)
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn(out)

        out += residual
        out = self.relu(out)

        return out


class UpConcat(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpConcat, self).__init__()

        self.de_conv = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2)

    def forward(self, inputs, down_outputs):
        outputs = self.de_conv(inputs)
        out = torch.cat([down_outputs, outputs], 1)
        return out


class Res_UNet(nn.Module):
    def __init__(self, num_channels=1, num_classes=1):
        super(Res_UNet, self).__init__()
        flt = 64
        self.down1 = Encode_Block(num_channels, flt)
        self.down2 = Encode_Block(flt, flt * 2)
        self.down3 = Encode_Block(flt * 2, flt * 4)
        self.down4 = Encode_Block(flt * 4, flt * 8)
        self.down_pool = nn.MaxPool2d(kernel_size=2)
        self.bottom = Encode_Block(flt * 8, flt * 16)
        self.up_cat1 = UpConcat(flt * 16, flt * 8)
        self.up_conv1 = Decode_Block(flt * 16, flt * 8)
        self.up_cat2 = UpConcat(flt * 8, flt * 4)
        self.up_conv2 = Decode_Block(flt * 8, flt * 4)
        self.up_cat3 = UpConcat(flt * 4, flt * 2)
        self.up_conv3 = Decode_Block(flt * 4, flt * 2)
        self.up_cat4 = UpConcat(flt * 2, flt)
        self.up_conv4 = Decode_Block(flt * 2, flt)
        self.final = nn.Sequential(
            nn.Conv2d(flt, num_classes, kernel_size=1), nn.Sigmoid()
        )

    def forward(self, inputs):
        down1_feat = self.down1(inputs)
        pool1_feat = self.down_pool(down1_feat)
        down2_feat = self.down2(pool1_feat)
        pool2_feat = self.down_pool(down2_feat)
        down3_feat = self.down3(pool2_feat)
        pool3_feat = self.down_pool(down3_feat)
        down4_feat = self.down4(pool3_feat)
        pool4_feat = self.down_pool(down4_feat)

        bottom_feat = self.bottom(pool4_feat)

        up1_feat = self.up_cat1(bottom_feat, down4_feat)
        up1_feat = self.up_conv1(up1_feat)
        up2_feat = self.up_cat2(up1_feat, down3_feat)
        up2_feat = self.up_conv2(up2_feat)
        up3_feat = self.up_cat3(up2_feat, down2_feat)
        up3_feat = self.up_conv3(up3_feat)
        up4_feat = self.up_cat4(up3_feat, down1_feat)
        up4_feat = self.up_conv4(up4_feat)

        outputs = self.final(up4_feat)

        return outputs

Модель полностью не сегментируется после установки model.eval(), но модель исправна после удаления model.eval(). Я запутался в этом, и необходим ли model.eval() в тесте?

1 Ответ

1 голос
/ 08 ноября 2019

Слои BatchNorm сохраняют оценки его вычисленного среднего значения и дисперсии во время обучения model.train(), которые затем используются для нормализации во время оценки model.eval().

Каждый слой имеет свою собственную статистику среднего значения и дисперсии своих выходов / активаций. Поскольку вы многократно используете слой BatchNorm self.bn = nn.BatchNorm2d(planes) несколько раз, статика смешивается и не отражает фактическое среднее значение и дисперсию. Поэтому вы должны создавать новый слой BatchNorm для каждого его использования.

РЕДАКТИРОВАТЬ: я только что прочитал, что ваш batch_size равен 1, что также может быть ядром вашей проблемы: см. Tensorflow и Batch Normalizationс размером партии == 1 => выводит все нули

...