Проверьте теорему свертки, используя pytorch - PullRequest
3 голосов
/ 06 марта 2020

В основном эта теорема сформулирована следующим образом:

F (f * g) = F (f) xF (g)

Я знаю эту теорему, но я просто не может воспроизвести результат с помощью pytorch.

Ниже приведен воспроизводимый код:

import torch
import torch.nn.functional as F

# calculate f*g
f = torch.ones((1,1,5,5))
g = torch.tensor(list(range(9))).view(1,1,3,3).float()
conv = F.conv2d(f, g, bias=None, padding=2)

# calculate F(f*g)
F_fg = torch.rfft(conv, signal_ndim=2, onesided=False)

# calculate F x G
f = f.squeeze()
g = g.squeeze()

# need to pad into at least [w1+w2-1, h1+h2-1], which is 7 in our case.
size = f.size(0) + g.size(0) - 1 

f_new = torch.zeros((7,7))
g_new = torch.zeros((7,7))

f_new[1:6,1:6] = f
g_new[2:5,2:5] = g

F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
FxG = torch.mul(F_f, F_g)

print(FxG - F_fg)

вот результат для print (FxG - F_fg)

tensor([[[[[ 0.0000e+00,  0.0000e+00],
       [ 4.1426e+02,  1.7270e+02],
       [-3.6546e+01,  4.7600e+01],
       [-1.0216e+01, -4.1198e+01],
       [-1.0216e+01, -2.0223e+00],
       [-3.6546e+01, -6.2804e+01],
       [ 4.1426e+02, -1.1427e+02]],

      ...

      [[ 4.1063e+02, -2.2347e+02],
       [-7.6294e-06,  2.2817e+01],
       [-1.9024e+01, -9.0105e+00],
       [ 7.1708e+00, -4.1027e+00],
       [-2.6739e+00, -1.1121e+01],
       [ 8.8471e+00,  7.1710e+00],
       [ 4.2528e+01,  9.7559e+01]]]]])

и вы видите, что разница не всегда равна 0.

Может кто-нибудь сказать мне, почему и как это сделать правильно?

Спасибо

1 Ответ

3 голосов
/ 08 марта 2020

Итак, я внимательно посмотрел на то, что вы сделали до сих пор. Я определил три источника ошибок в вашем коде. Я постараюсь в достаточной степени обратиться к каждому из них здесь.

1. Комплексная арифметика c

PyTorch в настоящее время не поддерживает умножение комплексных чисел (AFAIK). Операция FFT просто возвращает тензор с действительным и мнимым измерением. Вместо использования torch.mul или оператора * нам нужно явно кодировать сложное умножение.

(a + ib) * (c + id) = (a *c - b * d) + i (a * d + b *c)

2. Определение свертки

Определение «свертки», часто используемое в литературе CNN, фактически отличается от определения, используемого при обсуждении теоремы о свертке. Я не буду go подробно описывать, но теоретическое определение 1016 * переворачивает ядро ​​перед скольжением и умножением. Вместо этого операция свертки в pytorch, tenorflow, caffe и др. c ... не выполняет это переключение.

Чтобы учесть это, мы можем просто перевернуть g (как по горизонтали, так и по вертикали) перед применяя БПФ.

3. Позиция привязки

При использовании теоремы свертки точка привязки считается верхним левым углом дополненной g. Опять же, я не буду go подробно об этом, но именно так работает математика.


Второй и третий пункт может быть легче понять на примере. Предположим, вы использовали следующее g

[1 2 3]
[4 5 6]
[7 8 9]

вместо g_new, являющегося

[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]

, на самом деле оно должно быть

[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]

, где мы переворачиваем ядро вертикально и горизонтально, затем примените круговое смещение, чтобы центр ядра находился в верхнем левом углу.


В итоге я переписал большую часть вашего кода и немного его обобщил. Самая сложная операция - это правильное определение g_new. Я решил использовать meshgrid и modulo arithmeti c, чтобы одновременно переворачивать и сдвигать индексы. Если что-то здесь не имеет смысла для вас, пожалуйста, оставьте комментарий, и я постараюсь уточнить.

import torch
import torch.nn.functional as F

def conv2d_pyt(f, g):
    assert len(f.size()) == 2
    assert len(g.size()) == 2

    f_new = f.unsqueeze(0).unsqueeze(0)
    g_new = g.unsqueeze(0).unsqueeze(0)

    pad_y = (g.size(0) - 1) // 2
    pad_x = (g.size(1) - 1) // 2

    fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
    return fcg[0, 0, :, :]

def conv2d_fft(f, g):
    assert len(f.size()) == 2
    assert len(g.size()) == 2

    # in general not necessary that inputs are odd shaped but makes life easier
    assert f.size(0) % 2 == 1
    assert f.size(1) % 2 == 1
    assert g.size(0) % 2 == 1
    assert g.size(1) % 2 == 1

    size_y = f.size(0) + g.size(0) - 1
    size_x = f.size(1) + g.size(1) - 1

    f_new = torch.zeros((size_y, size_x))
    g_new = torch.zeros((size_y, size_x))

    # copy f to center
    f_pad_y = (f_new.size(0) - f.size(0)) // 2
    f_pad_x = (f_new.size(1) - f.size(1)) // 2
    f_new[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f

    # anchor of g is 0,0 (flip g and wrap circular)
    g_center_y = g.size(0) // 2
    g_center_x = g.size(1) // 2
    g_y, g_x = torch.meshgrid(torch.arange(g.size(0)), torch.arange(g.size(1)))
    g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(0)
    g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(1)
    g_new[g_new_y, g_new_x] = g[g_y, g_x]

    # take fft of both f and g
    F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
    F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)

    # complex multiply
    FxG_real = F_f[:, :, 0] * F_g[:, :, 0] - F_f[:, :, 1] * F_g[:, :, 1]
    FxG_imag = F_f[:, :, 0] * F_g[:, :, 1] + F_f[:, :, 1] * F_g[:, :, 0]
    FxG = torch.stack([FxG_real, FxG_imag], dim=2)

    # inverse fft
    fcg = torch.irfft(FxG, signal_ndim=2, onesided=False)

    # crop center before returning
    return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]


# calculate f*g
f = torch.randn(11, 7)
g = torch.randn(5, 3)

fcg_pyt = conv2d_pyt(f, g)
fcg_fft = conv2d_fft(f, g)

avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()

print('Average difference:', avg_diff)

Что дает мне

Average difference: 4.6866085767760524e-07

Это очень близко к нулю. Причина, по которой мы не получаем ровно ноль, заключается просто в ошибках с плавающей запятой.

...