Pytorch unet semanti c сегментация - PullRequest
1 голос
/ 16 января 2020

У меня есть версия Tensorflow v1 unet, которая очень хорошо тренируется с использованием SGD и скорости обучения 0,05.

Я переписал сеть в Pytorch, поскольку хочу показать некоторые функции, которые не так просто в Tensorflow.

Моя модель постоянно предсказывает пустую маску, поэтому я попытался наложить модель на одно изображение.

Возможно наложение одного примера изображения для прогнозирования одной маски , но это работает только с Адамом, скорость обучения 0,0005 и 1000 эпох. Моя старая модель может сделать это за 10 эпох или около того.

Я не вижу ничего очевидного, что я делаю неправильно. Я, должно быть, что-то делаю неправильно, так как это тривиальная проблема, которая требует небольшой настройки.

import numpy as np
import cv2
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, relu=True):
        super().__init__()
        if relu:
            self.double_conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )
        else:
            self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True, relu=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels, relu=relu)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

def decode_segmap(image, num_classes=3):

  label_colors = np.array([(128, 0, 0),
               (0, 128, 0), (0, 0, 128)])

  r = np.zeros_like(image).astype(np.uint8)
  g = np.zeros_like(image).astype(np.uint8)
  b = np.zeros_like(image).astype(np.uint8)

  for l in range(0, num_classes):
    idx = image == l
    r[idx] = label_colors[l, 0]
    g[idx] = label_colors[l, 1]
    b[idx] = label_colors[l, 2]

  rgb = np.stack([r, g, b], axis=2)
  return rgb


def load_batch(batch_size):
    rotated_frame = Image.open('0test.png')
    rotated_gt = Image.open('0label.png')

    trf = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean = [0.2455],  std = [0.2684])])

    rotated_frame = trf(rotated_frame).unsqueeze(0)

    trf = transforms.Compose([
                    transforms.ToTensor()])
    rotated_gt = trf(rotated_gt).unsqueeze(0)

    rotated_frame = torch.mean(rotated_frame, 1).unsqueeze(1)
    rotated_gt = torch.mean(rotated_gt, 1).unsqueeze(1)

    return rotated_frame.to(device), rotated_gt.type(torch.long).to(device).squeeze(1)


net = UNet(1, 3)
net.to(device=device)

# Loss
#optimizer = optim.RMSprop(net.parameters(), lr=0.005, weight_decay=1e-8)
optimizer = optim.SGD(net.parameters(), lr=0.0005)
#optimizer = optim.Adam(net.parameters(), lr=0.0005)

criterion = nn.CrossEntropyLoss()

# Load data
rotated_frame, rotated_gt = load_batch(1)
print(rotated_frame.shape)
print(rotated_gt.shape)

# Train
epochs = 1000
losses = [] 
for epoch in range(epochs):
    predicted = net(rotated_frame)
    loss = criterion(predicted, rotated_gt)
    losses.append(loss)
    loss.backward()
    optimizer.step()
    print('Epoch {}/{} Loss: {}'.format(epoch, epochs, loss))

output = torch.argmax(predicted.squeeze(), dim=0).detach().cpu().numpy()

a, b = np.min(output), np.max(output)
print('Predicted: min: {} max: {}'.format(a, b))
print(output.shape)
rgb = decode_segmap(output)
plt.imshow(rgb)
plt.savefig('predicted_argmaxed.png')

gt = rotated_gt.squeeze().detach().cpu().numpy()
a, b = np.min(gt), np.max(gt)
print('Gt: min: {} max: {}'.format(a, b))
rgb = decode_segmap(gt)
plt.imshow(rgb)
plt.savefig('gt_argmaxed.png')

Примеры изображений здесь:

0test.png 0label.png

Буду признателен за любую помощь!

1 Ответ

0 голосов
/ 16 января 2020

Если вы используете CrossEntropyLoss пытались ли вы добавить весовые коэффициенты для классов?

weights = torch.tensor([0.75, 1], dtype=torch.float)
criterion = torch.nn.CrossEntropyLoss(weight=weights,
                                      reduction='none').to(device)

Если ваша модель выдает пустую маску (например, белую), теоретически она сводит к минимуму потери, поскольку наличие полностью белого изображения кажется наиболее заметным классом, в зависимости от количества классов, которые вы пытаетесь добавить весу к классу границ.

Веса, которые вы видите там, я использовал, когда был делать бинарную классификацию, когда класс равен 70%, а остальные 30%.

В противном случае BN, как упоминал Натт, также может помочь. Ваша скорость обучения также кажется слишком низкой.

Редактировать : Просто для пояснения, из документов:

weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C
reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...