Я реализовал функцию ND Hartley Pooling (см. в этом документе для получения дополнительной информации) для Pytorch, используя cupy
и sigpy
, как показано ниже.
Хотя эта функция Кажется, работает нормально, есть ли способ дальнейшей оптимизации этого кода, чтобы сделать его быстрее?
Меня интересуют до четырехмерных данных, поэтому я попытался использовать N-размерность везде, где это возможно. К сожалению, Pytorch, похоже, не имеет 4-мерной функции FFT, поэтому вместо этого я использовал другие библиотеки GPU для выполнения FFT на GPU.
import torch
import torch.nn as nn
from torch.autograd import Function
import math
import operator
import cupy as cp
import sigpy as sp
def _spectral_crop(array, array_shape, bounding_shape):
start = tuple(map(lambda a, da: (a-da)//2, array_shape, bounding_shape))
end = tuple(map(operator.add, start, bounding_shape))
slices = tuple(map(slice, start, end))
return array[slices]
def _spectral_pad(array, array_shape, bounding_shape):
out = cp.zeros(bounding_shape)
start = tuple(map(lambda a, da: (a-da)//2, bounding_shape, array_shape))
end = tuple(map(operator.add, start, array_shape))
slices = tuple(map(slice, start, end))
out[slices] = array
return out
def DiscreteHartleyTransform(input):
N = input.ndim
axes_n = np.arange(2,N)
fft = sp.fft(input, axes=axes_n)
H = fft.real - fft.imag
return H
def CropForward(input, return_shape):
output_shape = np.zeros(input.ndim).astype(int)
output_shape[0] = input.shape[0]
output_shape[1] = input.shape[1]
output_shape[2:] = np.asarray(return_shape).astype(int)
dht = DiscreteHartleyTransform(input)
dht = _spectral_crop(dht, dht.shape, output_shape)
dht = DiscreteHartleyTransform(dht)
return dht
def PadBackward(grad_output, input_shape):
dht = DiscreteHartleyTransform(grad_output)
dht = _spectral_pad(dht, dht.shape, input_shape)
dht = DiscreteHartleyTransform(dht)
return dht
class SpectralPoolingFunction(Function):
@staticmethod
def forward(ctx, input, return_shape):
input = sp.from_pytorch(input)
ctx.input_shape = input.shape
output = CropForward(input, return_shape)
output = sp.to_pytorch(output)
output = output.float()
return output
@staticmethod
def backward(ctx, grad_output):
grad_output = sp.from_pytorch(grad_output)
grad_input = PadBackward(grad_output, ctx.input_shape)
grad_input = sp.to_pytorch(grad_input)
grad_input = grad_input.float()
return grad_input, None, None
class SpectralPoolNd(nn.Module):
def __init__(self, return_shape):
super(SpectralPoolNd, self).__init__()
self.return_shape = return_shape
def forward(self, input):
return SpectralPoolingFunction.apply(input, self.return_shape)