Прямо сейчас я пытаюсь внедрить U-net для рисования изображений. Я получил свой код для запуска без ошибок, и при распечатке потерь в поездах и потерь при проверке он действительно сходится к некоторой точке. Изображения действительно напоминают правду о земле, но их качество довольно низкое. Когда я смотрю на значения пикселей в неокрашенной части, это далеко не сходится к истинным значениям. Кажется, будто он случайно меняет свои значения.
С другой стороны, я смог выполнить ту же задачу с Matconvnet, библиотекой MATLAB, так что, думаю, она должна работать и в pytorch.
Так что проблема в том, что я думаю, что у меня есть проблемы с реализацией U-net с pytorch. Тем не менее, Я не знаю, где я ошибаюсь в реализации, так как я не получаю никаких ошибок. Ниже моя реализация.
Я попытался создать заново, но это не сработало. Сначала я подумал, что это может быть проблема с гиперпараметрами или методами инициализации, поэтому я попытался также переключить их, но ни один из них не помог.
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNetConvBlock(nn.Module):
def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
super(UNetConvBlock, self).__init__()
self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
self.bn = nn.BatchNorm2d(out_size)
self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
self.bn2 = nn.BatchNorm2d(out_size)
self.activation = activation
def forward(self, x):
x1 = self.activation(self.bn(self.conv(x)))
out = self.activation(self.bn2(self.conv2(x1)))
return out
class UNetLastBlock(nn.Module):
def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
super(UNetLastBlock, self).__init__()
self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
self.bn = nn.BatchNorm2d(out_size)
self.conv2 = nn.Conv2d(out_size,in_size, kernel_size, padding=1)
self.bn2 = nn.BatchNorm2d(in_size)
self.activation = activation
def forward(self, x):
x1 = self.activation(self.bn(self.conv(x)))
out = self.activation(self.bn2(self.conv2(x1)))
return out
class UNetUpBlock(nn.Module):
def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
super(UNetUpBlock, self).__init__()
self.up = nn.ConvTranspose2d(in_size, in_size, 2, stride=2)
# Due to concat
self.conv = nn.Conv2d(in_size * 2, in_size, kernel_size, padding=1)
self.bn = nn.BatchNorm2d(in_size)
self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
self.bn2 = nn.BatchNorm2d(out_size)
self.activation = activation
def forward(self, x, bridge):
up = self.up(x)
out = torch.cat([up, bridge], dim=1)
out = self.activation(self.bn(self.conv(out)))
out = self.activation(self.bn2(self.conv2(out)))
return out
class UNet(nn.Module):
def __init__(self, in_c, out_c):
super(UNet, self).__init__()
self.activation = F.relu
self.pool1 = nn.MaxPool2d(2)
self.pool2 = nn.MaxPool2d(2)
self.pool3 = nn.MaxPool2d(2)
self.pool4 = nn.MaxPool2d(2)
self.conv_block2_64 = UNetConvBlock(in_c, 64)
self.conv_block64_128 = UNetConvBlock(64, 128)
self.conv_block128_256 = UNetConvBlock(128, 256)
self.conv_block256_512 = UNetConvBlock(256, 512)
self.conv_block512_1024 = UNetLastBlock(512, 1024)
self.up_block1024_512 = UNetUpBlock(512, 256)
self.up_block512_256 = UNetUpBlock(256, 128)
self.up_block256_128 = UNetUpBlock(128, 64)
self.up_block128_64 = UNetUpBlock(64, 64)
self.last = nn.Conv2d(64, out_c, 1)
def forward(self, x):
block1 = self.conv_block2_64(x)
pool1 = self.pool1(block1)
block2 = self.conv_block64_128(pool1)
pool2 = self.pool2(block2)
block3 = self.conv_block128_256(pool2)
pool3 = self.pool3(block3)
block4 = self.conv_block256_512(pool3)
pool4 = self.pool4(block4)
block5 = self.conv_block512_1024(pool4)
up1 = self.up_block1024_512(block5, block4)
up2 = self.up_block512_256(up1, block3)
up3 = self.up_block256_128(up2, block2)
up4 = self.up_block128_64(up3, block1)
return self.last(up4)