функция conv2d в pytorch - PullRequest
       35

функция conv2d в pytorch

1 голос
/ 05 мая 2019

Я пытаюсь использовать функцию torch.conv2d из Pytorch, но не могу получить результат, который я понимаю ...

Вот простой пример, где ядро ​​(filt) имеет тот же размер, что и вход (im), чтобы объяснить, что я ищу.

import pytorch

filt = torch.rand(3, 3)
im = torch.rand(3, 3)

Я хочу вычислить простую свертку без заполнения ,поэтому результатом должен быть скаляр (т. е. тензор 1x1).

Я пробовал это с conv2d:

# I have to convert image and kernel to 4 dimensions tensors to use conv2d
im_torch = im.reshape((im_height, filt_height, 1, 1))
filt_torch = filt.reshape((filt_height, im_height, 1, 1))
out = torch.nn.functional.conv2d(im_torch, filt_torch, stride=1, padding=0)
print(out)

Но результат не тот, который я ожидал:

tensor([[[[0.6067]], [[0.3564]], [[0.5397]]],
    [[[0.2557]], [[0.0493]], [[0.2562]]],
    [[[0.6067]], [[0.3564]], [[0.5397]]]])

Чтобы дать представление о том, что я хотел бы, я хочу воспроизвести поведение scipy convolve2d:

import scipy.signal
out_scipy = scipy.signal.convolve2d(im.detach().numpy(), filt.detach().numpy(), 'valid')
print(out_scipy)

, которое печатает:

array([[1.195723]], dtype=float32)

Ответы [ 2 ]

2 голосов
/ 13 мая 2019

Тензорная форма вашего ввода и фильтра должна быть:

(batch, dim_ch, width, height)

и НЕ:

(width, height, 1, 1)

, например

import torch
import torch.nn.functional as F
x = torch.randn(1,1,4,4);
y = torch.randn(1,1,4,4);
z = F.conv2d(x,y);

Форма вывода z:

torch.Size([1,1,1,1])
1 голос
/ 13 мая 2019

Хорошо, я не нашел точного ответа на свой вопрос (т.е. как использовать conv2d), но я нашел другой способ сделать это.

Прежде всего, я узнал, что я ищу, это называется valid кросс-корреляция, и это фактически операция, реализуемая классом [Conv2d][1].

Следовательно, мое решение использует класс Conv2d вместо функции conv2d.

import pytorch

img = torch.rand(3, 3)

model = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), stride=1, padding=0, bias=False)

res = conv_mdl(img)
print(res.shape)

Который печатает скаляр, который я хотел:

torch.Size([1, 1, 1, 1])

PS: Я также проверил, что результат правильный, а не только размер.

...