Оптимизация Pytorch Hartley Pooling - PullRequest
0 голосов
/ 10 марта 2020

Я реализовал функцию ND Hartley Pooling (см. в этом документе для получения дополнительной информации) для Pytorch, используя cupy и sigpy, как показано ниже.

Хотя эта функция Кажется, работает нормально, есть ли способ дальнейшей оптимизации этого кода, чтобы сделать его быстрее?

Меня интересуют до четырехмерных данных, поэтому я попытался использовать N-размерность везде, где это возможно. К сожалению, Pytorch, похоже, не имеет 4-мерной функции FFT, поэтому вместо этого я использовал другие библиотеки GPU для выполнения FFT на GPU.

import torch
import torch.nn as nn
from torch.autograd import Function
import math
import operator
import cupy as cp
import sigpy as sp



def _spectral_crop(array, array_shape, bounding_shape):
    start = tuple(map(lambda a, da: (a-da)//2, array_shape, bounding_shape))
    end = tuple(map(operator.add, start, bounding_shape))
    slices = tuple(map(slice, start, end))
    return array[slices]

def _spectral_pad(array, array_shape, bounding_shape):
    out = cp.zeros(bounding_shape)
    start = tuple(map(lambda a, da: (a-da)//2, bounding_shape, array_shape))
    end = tuple(map(operator.add, start, array_shape))
    slices = tuple(map(slice, start, end))
    out[slices] = array
    return out

def DiscreteHartleyTransform(input):
    N = input.ndim
    axes_n = np.arange(2,N)
    fft = sp.fft(input, axes=axes_n)
    H = fft.real - fft.imag
    return H

def CropForward(input, return_shape):

    output_shape = np.zeros(input.ndim).astype(int)
    output_shape[0] = input.shape[0]
    output_shape[1] = input.shape[1]
    output_shape[2:] = np.asarray(return_shape).astype(int)

    dht = DiscreteHartleyTransform(input)
    dht = _spectral_crop(dht, dht.shape, output_shape)
    dht = DiscreteHartleyTransform(dht)

    return dht

def PadBackward(grad_output, input_shape):
    dht = DiscreteHartleyTransform(grad_output)
    dht = _spectral_pad(dht, dht.shape, input_shape)
    dht = DiscreteHartleyTransform(dht)

    return dht


class SpectralPoolingFunction(Function):
    @staticmethod
    def forward(ctx, input, return_shape):
         input = sp.from_pytorch(input)
         ctx.input_shape = input.shape
         output = CropForward(input, return_shape)
         output = sp.to_pytorch(output)
         output = output.float()
         return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = sp.from_pytorch(grad_output)
        grad_input = PadBackward(grad_output, ctx.input_shape)
        grad_input = sp.to_pytorch(grad_input)
        grad_input = grad_input.float()
       return grad_input, None, None


class SpectralPoolNd(nn.Module):
    def __init__(self, return_shape):
        super(SpectralPoolNd, self).__init__()
        self.return_shape = return_shape

    def forward(self, input):
        return SpectralPoolingFunction.apply(input, self.return_shape)
...