Более быстрая альтернатива для вычисления свертки в частотной области в факеле - PullRequest
0 голосов
/ 18 марта 2019

Я реализую пользовательскую свертку в torch, которая преобразует изображение в частотную область, используя FFT, вычисляет произведение между ядром и изображением, а затем вычисляет обратное FFT.Хотя это работает, я заметил, что вычисление продукта довольно медленное.Есть ли способ оптимизировать его?

Я добавил timer ко всему, чтобы посмотреть, как это происходит, и получил следующие результаты (тестирование на cpu):

squeezes - done in 4.17232513E-05s
real - done in 1.67846680E-04s
im - done in 7.53402710E-05s
stack - done in 8.36849213E-05s
sum - done in 3.96490097E-04s
bias - done in 1.64508820E-05s

Вот моя реализация:

Обратите внимание, что здесь я не вычисляю FFT и его инверсию.torch реализация этих операций довольно быстрая.

def fconv2d(input, kernel, bias=None):
    # Computes the convolution in the frequency domain given
    # an input of shape (B, Cin, H, W) and kernel of shape (Cout, Cin, H, W).
    # Expects input and kernel already in frequency domain!

    with timer('squeezes'):
        kernel = kernel.unsqueeze(0)
        # Expand kernel to (B, Cout, Cin, H, W)
        # Expand input to (B, Cout, Cin, H, W)
        input = input.unsqueeze(1)
    # Compute the multiplication
    # (a+bj)*(c+dj) = (ac-bd)+(ad+bc)j
    with timer('real'):
        real = input[..., 0] * kernel[..., 0] - \
               input[..., 1] * kernel[..., 1]
    with timer('im'):
        im = input[..., 0] * kernel[..., 1] + \
             input[..., 1] * kernel[..., 0]
    # Stack both channels and sum-reduce the input channels dimension
    with timer('stack'):
        out = torch.stack([real, im], -1)

    with timer('sum'):
        out = out.sum(dim=-4)
    # Add bias
    with timer('bias'):
        if bias is not None:
            bias = bias.expand(1, 1, 1, bias.shape[0]).permute(0, 3, 1, 2)
            out += bias
    return out
...