Я думаю, что ваша проблема в том, что вы используете torch.nn.functional
вместо torch
. Целью функционального API является выполнение операций (в данном случае conv2d) напрямую, без создания экземпляра класса и последующего вызова его метода forward. Следовательно, утверждение self.conv1 = F.conv2d(inputs,kernela,stride=1,padding=1)
уже выполняет свертку между input
и kernela
, и то, что у вас есть в self.conv1
, является результатом такой свертки. Здесь есть два подхода для решения проблемы. Используйте torch.Conv2d
внутри __init__
, где inputs
- это канал входа, а не тензор с той же формой, что и ваш реальный вход. И второй подход - придерживаться функционального API, но перенести его на метод forward()
. То, чего вы хотите достичь, можно сделать, изменив форвард на:
def forward(self, x ):
print(x.shape)
G_x = F.conv2d(x,self.kernela,stride=1,padding=1)
G_y = F.conv2d(x,self.kernelb,stride=1,padding=1)
out = torch.sqrt(torch.pow(G_x,2)+ torch.pow(G_y,2))
return out
Обратите внимание, что я создал kernela
и kernelb
атрибуты класса. Таким образом, вы также должны изменить __init__()
на
def __init__(self):
super(SobelFilter, self).__init__()
kernel1=torch.Tensor([[1, 0, -1],[2,0,-2],[1,0,-1]])
self.kernela=kernel1.expand((1,1,3,3))
kernel2=torch.Tensor([[1, 2, 1],[0,0,0],[-1,-2,-1]])
self.kernelb=kernel2.expand((1,1,3,3))