Я использую функции affine_grid
и grid_sample
для реализации билинейного модуля интерполяции ROIAlign. Но проблема в том, что только части результата верны по сравнению с torchvision.ops.ROIAlign
. Я действительно запутался и понятия не имею, что делать. Надеюсь, кто-нибудь поможет мне. Спасибо!
PS Моя версия факела - 1.3.1
. И нет необходимости смотреть каждую строку моего кода. Я протестировал AvgPool class
и часть BilinearIntp.forward()
для изменения размера и нашел, что они хорошо работают. Поэтому, возможно, вы могли бы сосредоточиться только на части affine_grid
в BilinearIntp.forward()
.
код:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import RoIAlign, RoIPool
from torch.autograd import Variable
class BilinearIntp(nn.Module):
def __init__(self, sampling_ratio, output_size, spatial_scale):
super().__init__()
self.sampling_ratio = sampling_ratio # sampling_ratio means sampling sampling_ratio x sampling_ratio points in a bin.
self.output_size = output_size # output_size means there are output_size[0] x output_size[1] bins
self.spatial_scale = spatial_scale
self.intp_mode = 'bilinear'
def forward(self, x, rois=None):
indices = rois[:,0].long()
x = x[indices]
k = x.shape[0]
c = x.shape[1]
H = x.shape[2]
W = x.shape[3]
x1 = rois[:,1::4] * self.spatial_scale
y1 = rois[:,2::4] * self.spatial_scale
x2 = rois[:,3::4] * self.spatial_scale
y2 = rois[:,4::4] * self.spatial_scale
roi_h = x2 - x1
roi_w = y2 - y1
#I calculate the offset to align the corner
offset_h = roi_h / (self.output_size[0] * self.sampling_ratio * 2)
offset_w = roi_w / (self.output_size[1] * self.sampling_ratio * 2)
x1 = x1 + offset_h
y1 = y1 + offset_w
x2 = x2 - offset_h
y2 = y2 - offset_w
zero = Variable(rois.data.new(k, 1).zero_())
# theta is the affine matrix.
# I get the following formula from http://www.telesens.co/2018/03/11/object-detection-and-classification-using-r-cnns/#ITEM-1455-4
theta = torch.cat([\
(x2 - x1) / (H - 1),
zero,
(x1 + x2 - H + 1) / (H - 1),
zero,
(y2 - y1) / (W - 1),
(y1 + y2 - W + 1) / (W - 1)], 1).view(-1, 2, 3)
grid = F.affine_grid(theta, torch.Size((k, c, self.sampling_ratio*self.output_size[0], self.sampling_ratio*self.output_size[1])), align_corners=True)
x = F.grid_sample(x, grid, mode=self.intp_mode, align_corners=True, padding_mode='zeros')
#resize x to do the pooling correctly
x = x.view(k, c, self.output_size[0], self.sampling_ratio, self.output_size[1], self.sampling_ratio)
x = x.transpose(3, 4).contiguous()
x = x.view(k, c*self.output_size[0]*self.output_size[1], self.sampling_ratio, self.sampling_ratio)
return x
class AvgPool(nn.Module):
def __init__(self, kernel_size):
super().__init__()
self.kernel_size = kernel_size
def forward(self, x):
x = F.avg_pool2d(x, self.kernel_size)
return x
if __name__ == '__main__':
output_size = (3,3)
spatial_scale = 1/4
sampling_ratio = 2
x = torch.randn(1,1,6,6)
rois = torch.tensor([
[0,1.0,6.6,6.7,10.1],
[0,4.4,10.2,11.5,14.5],
[0,0.89,10.7,2.5,13.5],
])
channel_num = x.shape[1]
roi_num = rois.shape[0]
a = RoIAlign(output_size, spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
b = BilinearIntp(sampling_ratio, output_size, spatial_scale)
bb = AvgPool(sampling_ratio)
ya = a(x, rois)
yb = bb(b(x, rois))
yb = yb.view(roi_num, channel_num, output_size[0], output_size[1])
torch.set_printoptions(precision=8)
print('torch:\n', ya)
print('my:\n', yb)
print('IsEqual: ', yb.equal(ya))
один случай вывода:
torch:
tensor([[[[-0.07067531, 0.05316655, 0.58660626],
[-0.26260251, -0.20688595, 0.31051391],
[-0.14635098, -0.09702750, 0.03639890]]],
[[[-0.14888579, -0.41584462, -0.86192513],
[-0.41730025, -0.98373175, -0.99377835],
[-0.73961210, -1.03835821, -0.57192719]]],
[[[-0.03776532, 0.01777636, 0.01046431],
[-0.01859918, -0.06157728, -0.16563404],
[-0.08040507, -0.33431917, -0.57920480]]]])
my:
tensor([[[[-0.04292072, 0.08833585, 0.60396326],
[-0.28439966, -0.22748421, 0.36191046],
[-0.18267949, -0.13135812, 0.12205967]]],
[[[-0.14888588, -0.41584507, -0.86192548],
[-0.41730040, -0.98373216, -0.99377829],
[-0.73961198, -1.03835797, -0.57192695]]],
[[[-0.06950454, -0.04870241, -0.02790027],
[-0.01129269, 0.00485144, 0.02099557],
[-0.00545665, -0.04398080, -0.08250497]]]])
Мы можем узнать, что средняя часть выходных данных одинакова, хотя есть потеря точности, что действительно странно.