Я просто пытался провести эксперимент с использованием PyTorch, где я пытаюсь вычислить матрицу аффинного преобразования из заданной пары изображений (исходное и преобразованное изображение). Для этого примера я просто использую небольшую сетку 5x5 с прямой линией в качестве исходного изображения и линией, наклоненной на 45 градусов в качестве преобразованного выходного сигнала. По какой-то причине кажется, что потери снижаются, а градиенты становятся все меньше и меньше (очевидно). Но решение, к которому оно сходится, похоже, далеко (полностью не похоже на прямую линию).
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(989)
# source_image = torch.tensor([[0,1,0],[0,1,0],[0,1,0]])
source_image = torch.tensor([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0]])
plt.imshow(source_image)
# transformed_image = torch.eye(3)
transformed_image = torch.eye(5)
plt.imshow(transformed_image)
source_image = source_image.reshape(1, 1, source_image.shape[0], source_image.shape[1])
transformed_image = transformed_image.reshape(1, 1, transformed_image.shape[0], transformed_image.shape[1])
source_image = source_image.type(torch.FloatTensor)
transformed_image = transformed_image.type(torch.FloatTensor)
class AffineNet(nn.Module):
def __init__(self):
super(AffineNet, self).__init__()
self.M = torch.nn.Parameter(torch.randn(1, 2, 3))
def forward(self, im):
flow_grid = F.affine_grid(self.M, transformed_image.size())
transformed_flow_image = F.grid_sample(transformed_image, flow_grid, padding_mode="border")
return transformed_flow_image
affineNet = AffineNet()
optimizer = optim.SGD(affineNet.parameters(), lr=0.01)
criterion = nn.MSELoss()
for i in range(1000):
optimizer.zero_grad()
output = affineNet(transformed_image)
loss = criterion(output, source_image)
loss.backward()
if(i%10==0):
print(i, loss.item(), affineNet.M.grad)
optimizer.step()
print(affineNet.M)
printme = output.detach().reshape(output.shape[2], output.shape[3])
plt.imshow(printme.cpu())
Кажется, это работает нормально, если вы возитесь с закомментированными строками и используете сетку 3х3, а не 5х5. Может кто-нибудь помочь мне понять, почему это происходит? Кажется, это сильно изменится, если я поиграю с семенем.