При обучении 3D U-Net наблюдается странное резкое падение потерь (похожее на «обрыв») при использовании взвешенной кросс-энтропийной потери или Dice-потери - PullRequest
0 голосов
/ 14 октября 2018

График Dice-потери:

The figure of dice loss

График взвешенной кросс-энтропийной потери:

The figure of weighted cross-entropy loss

Оптимизатор — Adam, lr = 0.0002, beta1 = 0.5, beta2 = 0.999. Сталкивался ли кто-нибудь с такой же проблемой? Можете ли вы объяснить её причину и подсказать решение?

Код модели 3D U-Net приведён ниже.

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

class UNet3D(nn.Module):
    '''
    Three-level 3D U-Net assembled from Conv2Block, DownBlock, UpBlock and OutBlock.

    The encoder applies an input Conv2Block followed by three DownBlocks whose
    channel counts are min_filters * 2, * 4 and * 8. The decoder applies three
    UpBlocks with the mirrored channel counts; each one also receives the
    encoder output of the matching level, which UpBlock concatenates in.

    Args:
        name: 'netS' builds the output block with is_image=False, 'netG' with
            is_image=True. Kept as the plain string attribute ``self.name``
            (a method called ``name`` would be shadowed by that attribute, so
            none is defined).
        in_channels: number of channels of the input volume.
        out_channels: number of channels produced by the output block.
        min_filters: number of channels at the first level.
        norm_layer: normalization layer constructor forwarded to every block.

    Raises:
        ValueError: if ``name`` is neither 'netS' nor 'netG'.
    '''

    # Maps each supported network name to the ``is_image`` flag of its OutBlock.
    _OUTPUT_IS_IMAGE = {'netS': False, 'netG': True}

    def __init__(self, name, in_channels, out_channels, min_filters=2, norm_layer=nn.BatchNorm3d):
        # Fail fast: without a matching name no output block would be created
        # and forward() would only fail later when it reaches self.out.
        if name not in UNet3D._OUTPUT_IS_IMAGE:
            raise ValueError(
                "unknown network name %r; expected one of %s"
                % (name, sorted(UNet3D._OUTPUT_IS_IMAGE))
            )

        super(UNet3D, self).__init__()
        self.name = name

        # One in-place LeakyReLU instance is shared by all blocks.
        activation = nn.LeakyReLU(0.2, True)

        self.input = Conv2Block(in_channels, min_filters, activation=activation, norm_layer=norm_layer)

        self.down1 = DownBlock(min_filters, min_filters*2, activation=activation, norm_layer=norm_layer)
        self.down2 = DownBlock(min_filters*2, min_filters*4, activation=activation, norm_layer=norm_layer)
        self.down3 = DownBlock(min_filters*4, min_filters*8, activation=activation, norm_layer=norm_layer)

        self.up1 = UpBlock(min_filters*8, min_filters*4, activation=activation, norm_layer=norm_layer)
        self.up2 = UpBlock(min_filters*4, min_filters*2, activation=activation, norm_layer=norm_layer)
        self.up3 = UpBlock(min_filters*2, min_filters, activation=activation, norm_layer=norm_layer)

        self.out = OutBlock(min_filters, out_channels, is_image=UNet3D._OUTPUT_IS_IMAGE[name])

    def forward(self, x):
        '''Run the encoder, then the decoder with skip inputs, then the output block.'''
        # Encoder: every level's output is kept for the matching decoder level.
        x1 = self.input(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x = self.down3(x3)

        # Decoder: deepest skip first.
        x = self.up1(x, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        return self.out(x)

class Conv2Block(nn.Module):
    '''
    Two successive Conv3d, each followed by norm_layer (if given) and activation.
        first one| in: in_channels, out: out_channels
        second one| in: out_channels, out: out_channels

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        activation: callable applied after each (optionally normalized) convolution.
        norm_layer: constructor called as norm_layer(out_channels), or None to
            skip normalization.
        kernel_size: kernel size of both convolutions.
    '''
    def __init__(self, in_channels, out_channels, activation, norm_layer=nn.BatchNorm3d, kernel_size=3):
        super(Conv2Block, self).__init__()
        # ceil((k - 1) / 2): with the default stride this keeps the spatial
        # size unchanged for odd kernel sizes.
        padding = int(np.ceil((kernel_size - 1) / 2))
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size, padding=padding)

        self.norm_layer = norm_layer
        if norm_layer is not None:
            self.bn1 = norm_layer(out_channels)
            self.bn2 = norm_layer(out_channels)

        self.activation = activation

    def forward(self, x):
        '''Apply conv -> (norm) -> activation twice and return the result.'''
        use_norm = self.norm_layer is not None

        x = self.conv1(x)
        if use_norm:
            x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        if use_norm:
            x = self.bn2(x)
        return self.activation(x)

class DownBlock(nn.Module):
    '''
    Max-pooling with kernel size 2, followed by a Conv2Block.

    Args:
        in_channels, out_channels, activation, norm_layer, kernel_size:
            forwarded unchanged to the Conv2Block.
    '''
    def __init__(self, in_channels, out_channels, activation, norm_layer=nn.BatchNorm3d, kernel_size=3):
        super(DownBlock, self).__init__()
        self.conv = Conv2Block(
            in_channels,
            out_channels,
            activation,
            norm_layer=norm_layer,
            kernel_size=kernel_size,
        )

    def forward(self, x):
        '''Pool x with a 2-wide max-pool window, then run the convolution block.'''
        # Functional form of nn.MaxPool3d(2): the stride defaults to the kernel size.
        pooled = F.max_pool3d(x, 2)
        return self.conv(pooled)

class UpConv(nn.Module):
    '''
    Trilinear resize to a requested size, then Conv3d + activation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        activation: callable applied to the convolution output.
        kernel_size: kernel size of the convolution.
    '''
    def __init__(self, in_channels, out_channels, activation, kernel_size=3):
        super(UpConv, self).__init__()
        pad = int(np.ceil((kernel_size - 1) / 2))
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=pad)
        self.activation = activation

    def forward(self, x, sz):
        '''Resize x to the size sz, then convolve and activate.'''
        resized = F.interpolate(x, sz, mode='trilinear', align_corners=True)
        # TODO: no normalization is applied after this convolution
        convolved = self.conv(resized)
        return self.activation(convolved)

class UpBlock(nn.Module):
    '''
    UpConv, channel-wise concatenation with a second tensor, then Conv2Block.

    The UpConv maps in_channels -> out_channels, and the Conv2Block applied
    to the concatenation is built with in_channels inputs.
    NOTE(review): this presumably requires the second tensor to carry
    in_channels - out_channels channels -- confirm against the callers.
    '''
    def __init__(self, in_channels, out_channels, activation, norm_layer=nn.BatchNorm3d, kernel_size=3):
        super(UpBlock, self).__init__()
        self.upconv = UpConv(
            in_channels,
            out_channels,
            activation,
            kernel_size=kernel_size,
        )
        self.conv = Conv2Block(
            in_channels,
            out_channels,
            activation,
            norm_layer=norm_layer,
            kernel_size=kernel_size,
        )

    def forward(self, x, x2):
        '''Resize x to the last three dimensions of x2, concatenate on dim 1, convolve.'''
        target_size = tuple(x2.shape[-3:])
        upsampled = self.upconv(x, target_size)
        merged = torch.cat([upsampled, x2], dim=1)
        return self.conv(merged)

class OutBlock(nn.Module):
    '''
    Conv3d with kernel size 1 followed by an output activation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        is_image: if true the output activation is Tanh, otherwise Sigmoid.
    '''
    def __init__(self, in_channels, out_channels, is_image):
        super(OutBlock, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 1)

        # NOTE(review): SoftDiceLoss and WeightedCrossEntropyLoss in this file
        # apply softmax / cross_entropy to their inputs themselves. If this
        # block's Sigmoid output is fed to them directly, an activation is
        # applied on top of another one -- confirm against the training loop.
        self.out = nn.Tanh() if is_image else nn.Sigmoid()

    def forward(self, x):
        '''Apply the convolution and then the output activation.'''
        return self.out(self.conv(x))

Фрагменты кода взвешенной кросс-энтропийной потери и Dice-потери.

class SoftDiceLoss(nn.Module):
    '''
    Soft Dice loss computed per sample and averaged over the batch:

        loss = 1 - mean_n( (2 * sum(p * g) + smooth) / (sum(p) + sum(g) + smooth) )

    where p is softmax(preds, dim=1) and g is labels, both flattened per sample.

    Args:
        smooth: additive smoothing term. It is added once to the numerator and
            once to the denominator, so the score is exactly 1 (loss 0) for a
            perfect prediction and never exceeds 1.
    '''
    def __init__(self, smooth=1.0):
        super(SoftDiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, preds, labels):
        '''
        Args:
            preds: raw scores; softmax is applied over dim 1 here.
            labels: tensor with the same number of elements per sample as preds
                (assumed to be one-hot along dim 1 -- TODO confirm with caller).

        Returns:
            Scalar tensor holding 1 - mean Dice score.
        '''
        preds = F.softmax(preds, dim=1)

        num = labels.size(0)
        # reshape instead of view: labels may arrive non-contiguous.
        m1 = preds.reshape(num, -1)
        m2 = labels.reshape(num, -1).to(m1.dtype)
        intersection = (m1 * m2).sum(1)

        # m1.sum(1) equals the number of voxels, since softmax sums to 1 per voxel.
        denominator = m1.sum(1) + m2.sum(1) + self.smooth
        score = (2. * intersection + self.smooth) / denominator
        return 1 - score.sum() / num

class WeightedCrossEntropyLoss(nn.Module):
    '''
    Cross-entropy loss that limits how many background voxels contribute.

    All foreground voxels are counted. Of the background voxels only the
    k = min(int(Np * alpha), Nn) with the largest loss are counted, where Np
    and Nn are the numbers of foreground and background voxels. The result is
    the mean over the Np + k voxels that were kept.

    Args:
        background_labels: class indices treated as background.
    '''
    def __init__(self, background_labels=(0, 5)):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.background_labels = tuple(background_labels)

    def forward(self, pred, gt, alpha=3.0):
        '''
        Args:
            pred: raw class scores with the class axis at dim 1.
            gt: target whose argmax over dim 1 gives the class index
                (assumed one-hot -- TODO confirm with caller).
            alpha: maximum ratio of kept background voxels to foreground voxels.

        Returns:
            Scalar tensor with the mean loss over the kept voxels; zero when
            no voxel is kept (no foreground voxels in the batch).
        '''
        # argmax already drops the class axis. An additional squeeze(dim=1)
        # would remove the first spatial axis whenever it has size 1.
        target = gt.argmax(dim=1)
        per_voxel = F.cross_entropy(pred, target, reduction='none')

        background = torch.zeros_like(target, dtype=torch.bool)
        for label in self.background_labels:
            background |= target == label

        neg = per_voxel[background]
        # ~mask, not 1 - mask: subtraction is not defined for bool tensors.
        pos = per_voxel[~background]

        Np, Nn = pos.numel(), neg.numel()
        # Use the same integer k for the selection and for the normalization.
        k = int(min(Np * alpha, Nn))

        if Np + k == 0:
            # Nothing kept: return a zero that is still attached to the graph
            # instead of dividing 0 by 0.
            return per_voxel.sum() * 0.0

        total = pos.sum()
        if k > 0:
            hardest, _ = torch.topk(neg, k)
            total = total + hardest.sum()

        return total / (Np + k)
...