Итак, я внимательно посмотрел на то, что вы сделали до сих пор. Я определил три источника ошибок в вашем коде. Я постараюсь в достаточной степени обратиться к каждому из них здесь.
1. Комплексная арифметика c
PyTorch в настоящее время не поддерживает умножение комплексных чисел (AFAIK). Операция FFT просто возвращает тензор с действительным и мнимым измерением. Вместо использования torch.mul
или оператора *
нам нужно явно кодировать сложное умножение.
(a + ib) * (c + id) = (a *c - b * d) + i (a * d + b *c)
2. Определение свертки
Определение «свертки», часто используемое в литературе CNN, фактически отличается от определения, используемого при обсуждении теоремы о свертке. Я не буду go подробно описывать, но теоретическое определение 1016 * переворачивает ядро перед скольжением и умножением. Вместо этого операция свертки в pytorch, tenorflow, caffe и др. c ... не выполняет это переключение.
Чтобы учесть это, мы можем просто перевернуть g
(как по горизонтали, так и по вертикали) перед применяя БПФ.
3. Позиция привязки
При использовании теоремы свертки точка привязки считается верхним левым углом дополненной g
. Опять же, я не буду go подробно об этом, но именно так работает математика.
Второй и третий пункт может быть легче понять на примере. Предположим, вы использовали следующее g
[1 2 3]
[4 5 6]
[7 8 9]
вместо g_new
, являющегося
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
, на самом деле оно должно быть
[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]
, где мы переворачиваем ядро вертикально и горизонтально, затем примените круговое смещение, чтобы центр ядра находился в верхнем левом углу.
В итоге я переписал большую часть вашего кода и немного его обобщил. Самая сложная операция - это правильное определение g_new
. Я решил использовать meshgrid и modulo arithmeti c, чтобы одновременно переворачивать и сдвигать индексы. Если что-то здесь не имеет смысла для вас, пожалуйста, оставьте комментарий, и я постараюсь уточнить.
import torch
import torch.nn.functional as F
def conv2d_pyt(f, g):
assert len(f.size()) == 2
assert len(g.size()) == 2
f_new = f.unsqueeze(0).unsqueeze(0)
g_new = g.unsqueeze(0).unsqueeze(0)
pad_y = (g.size(0) - 1) // 2
pad_x = (g.size(1) - 1) // 2
fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
return fcg[0, 0, :, :]
def conv2d_fft(f, g):
assert len(f.size()) == 2
assert len(g.size()) == 2
# in general not necessary that inputs are odd shaped but makes life easier
assert f.size(0) % 2 == 1
assert f.size(1) % 2 == 1
assert g.size(0) % 2 == 1
assert g.size(1) % 2 == 1
size_y = f.size(0) + g.size(0) - 1
size_x = f.size(1) + g.size(1) - 1
f_new = torch.zeros((size_y, size_x))
g_new = torch.zeros((size_y, size_x))
# copy f to center
f_pad_y = (f_new.size(0) - f.size(0)) // 2
f_pad_x = (f_new.size(1) - f.size(1)) // 2
f_new[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f
# anchor of g is 0,0 (flip g and wrap circular)
g_center_y = g.size(0) // 2
g_center_x = g.size(1) // 2
g_y, g_x = torch.meshgrid(torch.arange(g.size(0)), torch.arange(g.size(1)))
g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(0)
g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(1)
g_new[g_new_y, g_new_x] = g[g_y, g_x]
# take fft of both f and g
F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
# complex multiply
FxG_real = F_f[:, :, 0] * F_g[:, :, 0] - F_f[:, :, 1] * F_g[:, :, 1]
FxG_imag = F_f[:, :, 0] * F_g[:, :, 1] + F_f[:, :, 1] * F_g[:, :, 0]
FxG = torch.stack([FxG_real, FxG_imag], dim=2)
# inverse fft
fcg = torch.irfft(FxG, signal_ndim=2, onesided=False)
# crop center before returning
return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]
# calculate f*g
f = torch.randn(11, 7)
g = torch.randn(5, 3)
fcg_pyt = conv2d_pyt(f, g)
fcg_fft = conv2d_fft(f, g)
avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()
print('Average difference:', avg_diff)
Что дает мне
Average difference: 4.6866085767760524e-07
Это очень близко к нулю. Причина, по которой мы не получаем ровно ноль, заключается просто в ошибках с плавающей запятой.