Я реализую пользовательскую свертку в torch
, которая преобразует изображение в частотную область, используя FFT
, вычисляет произведение между ядром и изображением, а затем вычисляет обратное FFT
.Хотя это работает, я заметил, что вычисление продукта довольно медленное.Есть ли способ оптимизировать его?
Я добавил timer
ко всему, чтобы посмотреть, как это происходит, и получил следующие результаты (тестирование на cpu
):
squeezes - done in 4.17232513E-05s
real - done in 1.67846680E-04s
im - done in 7.53402710E-05s
stack - done in 8.36849213E-05s
sum - done in 3.96490097E-04s
bias - done in 1.64508820E-05s
Вот моя реализация:
Обратите внимание, что здесь я не вычисляю FFT
и его инверсию.torch
реализация этих операций довольно быстрая.
def fconv2d(input, kernel, bias=None):
# Computes the convolution in the frequency domain given
# an input of shape (B, Cin, H, W) and kernel of shape (Cout, Cin, H, W).
# Expects input and kernel already in frequency domain!
with timer('squeezes'):
kernel = kernel.unsqueeze(0)
# Expand kernel to (B, Cout, Cin, H, W)
# Expand input to (B, Cout, Cin, H, W)
input = input.unsqueeze(1)
# Compute the multiplication
# (a+bj)*(c+dj) = (ac-bd)+(ad+bc)j
with timer('real'):
real = input[..., 0] * kernel[..., 0] - \
input[..., 1] * kernel[..., 1]
with timer('im'):
im = input[..., 0] * kernel[..., 1] + \
input[..., 1] * kernel[..., 0]
# Stack both channels and sum-reduce the input channels dimension
with timer('stack'):
out = torch.stack([real, im], -1)
with timer('sum'):
out = out.sum(dim=-4)
# Add bias
with timer('bias'):
if bias is not None:
bias = bias.expand(1, 1, 1, bias.shape[0]).permute(0, 3, 1, 2)
out += bias
return out