У меня есть функции кодирования-декодирования.Я не могу декодировать свое изображение, чтобы окончательный вывод совпадал с размером ввода, который составляет [5,3,32,32].Как восстановить изображение в декодере так, чтобы размеры входного и выходного изображения совпадали?Пожалуйста, heeeelp !!!
from torch import nn
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
class UnFlatten(nn.Module):
def forward(self, input, size=512):
return input.view(input.size(0), size, 1, 1)
net = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=4, stride=2),
nn.ReLU(),
#nn.Conv2d(128, 256, kernel_size=4, stride=2),
#nn.ReLU(),
Flatten(),
nn.Linear(512, 32),
)
net2= nn.Sequential(
nn.Linear(32, 512),
UnFlatten(),
nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
nn.ReLU(),
nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2),
nn.Sigmoid(),
)
input = torch.zeros(5,3,32,32)
mu=net(input)
print("mu shape", mu.shape)
mu2= net2(mu)
print("mu2 shape", mu2.shape)