pybind11: линковка с внешней библиотекой C++ - PullRequest
0 голосов
/ 20 октября 2019

Я новичок в смешанном программировании на Python и C++. Я пытаюсь собрать модуль, в котором вызывается быстрое преобразование Фурье, поэтому линкую модуль с библиотекой FFTW. Модуль успешно собирается, но преобразование FFT работает неправильно, потому что план FFTW не создаётся.

Мой код приведён ниже. Я использую Ubuntu в подсистеме Windows для Linux (WSL) на Windows 10. FFTW установлена командой `sudo apt-get install libfftw3-dev`. Спасибо за помощь!

  1. myfft.hpp
#include <stdlib.h>
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <float.h>
#include <fftw3.h>

// State for one FFT-based solve, created by create_poisson_solver_workspace2d()
// and released by destroy_poisson_solver().
typedef struct{
    fftw_plan dctIn;      // forward transform: FFTW_REDFT10 on both axes, in place on `workspace`
    fftw_plan dctOut;     // backward transform: FFTW_REDFT01 on both axes, in place on `workspace`
    double *kernel;       // n1*n2 values from create_negative_laplace_kernel2d(); myFFT divides by them
    double *workspace;    // n1*n2 scratch buffer that both plans transform in place
}poisson_solver;

// Build the n1*n2 table of negative-Laplacian kernel values
//     2*n1^2*(1 - cos(pi*j/n1)) + 2*n2^2*(1 - cos(pi*i/n2))
// that myFFT() divides the forward-DCT coefficients by.
//
// Layout is row-major with n1 entries per row and n2 rows: entry (i, j) is
// stored at kernel[i*n1 + j], j in [0, n1) being the fast (contiguous) index.
//
// Returns a calloc'd buffer the caller must free(), or NULL if allocation fails.
//
// Declared `inline` because it is defined in a header: a non-inline definition
// becomes a multiple-definition link error once two translation units include it.
inline double *create_negative_laplace_kernel2d(int n1, int n2){
    const size_t count = (size_t)n1 * (size_t)n2;
    double *kernel = (double *)calloc(count, sizeof(double));
    if(kernel == NULL){
        return NULL;
    }

    // Scale factors are computed in double: `2*n1*n1` in int arithmetic
    // overflows once n1 exceeds ~32767.
    const double s1 = 4.0 * n1 * n1;
    const double s2 = 4.0 * n2 * n2;
    for(int i = 0; i < n2; i++){
        // 2*(1 - cos(t)) == 4*sin^2(t/2); the sine form avoids the cancellation
        // in 1 - cos(t) for the low frequencies near t = 0.
        const double sy = sin(M_PI * i / (2.0 * n2));
        for(int j = 0; j < n1; j++){
            const double sx = sin(M_PI * j / (2.0 * n1));
            kernel[(size_t)i * n1 + j] = s1 * sx * sx + s2 * sy * sy;
        }
    }
    return kernel;
}

// Allocate the buffers and FFTW plans for one n1 x n2 solve, where n1 is the
// fast (contiguous) axis and n2 the slow axis, matching the kernel layout.
//
// Both plans transform `workspace` in place. They are created with
// FFTW_MEASURE, which overwrites `workspace` while planning, so input data must
// be copied in only after this function returns.
//
// On any allocation or planning failure this prints a message to stderr and
// terminates the process with exit(1). Release the returned solver with
// destroy_poisson_solver().
//
// Declared `inline` because it is defined in a header.
inline poisson_solver create_poisson_solver_workspace2d(int n1, int n2){
    poisson_solver fftps;
    const size_t count = (size_t)n1 * (size_t)n2;

    fftps.workspace = (double *)calloc(count, sizeof(double));
    if(fftps.workspace == NULL){
        // Continuing would hand a NULL buffer to the FFTW planner.
        fprintf(stderr, "failed to alloc memory for workspace\n");
        exit(1);
    }
    fftps.kernel = create_negative_laplace_kernel2d(n1, n2);
    if(fftps.kernel == NULL){
        fprintf(stderr, "failed to alloc memory for kernel\n");
        exit(1);
    }

    // FFTW's 2-D API takes the slow dimension first: n2 rows of n1 values.
    fftps.dctIn = fftw_plan_r2r_2d(n2, n1, fftps.workspace, fftps.workspace,
                                   FFTW_REDFT10, FFTW_REDFT10,
                                   FFTW_MEASURE);
    fftps.dctOut = fftw_plan_r2r_2d(n2, n1, fftps.workspace, fftps.workspace,
                                    FFTW_REDFT01, FFTW_REDFT01,
                                    FFTW_MEASURE);
    // Both plans are executed by the caller, so both must be checked.
    // NOTE(review): if planning fails only when this code runs inside Python,
    // check whether another library in the process (e.g. Intel MKL's FFTW3
    // interface, pulled in by some NumPy builds) is providing the fftw_*
    // symbols; those wrappers may not implement r2r transforms — TODO confirm.
    if(fftps.dctIn == NULL || fftps.dctOut == NULL){
        fprintf(stderr, "failed to construct fftw plan\n");
        exit(1);
    }
    return fftps;
}

// Release everything owned by a solver returned by
// create_poisson_solver_workspace2d(): both FFTW plans and both buffers.
// The plans are destroyed before the workspace they were planned on is freed.
//
// Declared `inline` because it is defined in a header.
inline void destroy_poisson_solver(poisson_solver fftps){
    fftw_destroy_plan(fftps.dctIn);
    fftw_destroy_plan(fftps.dctOut);
    free(fftps.kernel);
    free(fftps.workspace);
}

void myFFT(double* mu, double *nu, int n1, int n2){
    poisson_solver fftps = create_poisson_solver_workspace2d(n1, n2);

    int pcount = n1 * n2;
    for(int i=0; i<pcount; i++){
        fftps.workspace[i] = mu[i];
    }
    fftw_execute(fftps.dctIn);
    fftps.workspace[0]=0;
    for(int i=1;i<pcount;i++){
        fftps.workspace[i] /= 4*pcount*fftps.kernel[i];      
    }
    fftw_execute(fftps.dctOut);
    for(int i=0;i<pcount;i++){
        nu[i] = fftps.workspace[i];
    }        
}

2. binding.cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include "myfft.hpp"

namespace py = pybind11;

// Python entry point: run myFFT() on a 2-D float64 array `mu` and return the
// result as a new array of the same shape.
//
// `mu` is converted (c_style | forcecast) to a C-contiguous float64 array, so
// for shape (rows, cols) each row is contiguous and `cols` is the fast axis.
// myFFT(mu, nu, n1, n2) treats n1 as its fast axis (it indexes
// kernel[i*n1 + j] and plans an n2 x n1 FFTW transform), so it must receive
// n1 = cols and n2 = rows. Passing shape[0] as n1 would silently solve the
// transposed problem for non-square inputs.
//
// Throws std::runtime_error (RuntimeError in Python) if `mu` is not 2-D.
py::array_t<double> FT(py::array_t<double, py::array::c_style | py::array::forcecast> mu)
{
    py::buffer_info in = mu.request();

    if (in.ndim != 2)
        throw std::runtime_error("Number of dimensions must be two.");

    const int rows = static_cast<int>(in.shape[0]);
    const int cols = static_cast<int>(in.shape[1]);

    // The result has the same (rows, cols) shape as the input.
    auto nu = py::array_t<double>({rows, cols});
    if (rows == 0 || cols == 0)
        return nu;  // empty input: nothing to transform

    py::buffer_info out = nu.request();
    double *src = static_cast<double *>(in.ptr);
    double *dst = static_cast<double *>(out.ptr);

    // Do not release the GIL around this call: myFFT creates FFTW plans, the
    // FFTW planner is not thread-safe, and the GIL is what serialises
    // concurrent calls from Python threads.
    myFFT(src, dst, cols, rows);
    return nu;
}

// Extension module `myFT`, exposing FT(mu).
PYBIND11_MODULE(myFT, m){
    m.doc() = "";

    // FT returns a freshly created py::array_t, so Python simply receives its
    // own reference. A return_value_policy has no effect on py::object-derived
    // return types, and `reference` would be the wrong intent for a new object,
    // so none is given. The named argument also allows FT(mu=...).
    m.def("FT", &FT, py::arg("mu"),
          "Run the DCT-based Poisson solve on a 2-D float64 array and return "
          "a new array of the same shape.");
}
3. CMakeLists.txt
cmake_minimum_required(VERSION 3.5)

project(myFT)

# Build as C++14. pybind11 honours CMAKE_CXX_STANDARD; the former
# `set(PYBIND_CPP_STANDARD -std=c++14)` misspelled pybind11's (now deprecated)
# PYBIND11_CPP_STANDARD variable and therefore had no effect.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_subdirectory(lib/pybind11)

set(SOURCE_DIR "csrc")

include_directories(${SOURCE_DIR})

set(SOURCES "${SOURCE_DIR}/binding.cpp")

# Locate FFTW3 up front so a missing library is a clear configure-time error
# instead of a link-time failure on a bare library name.
find_library(FFTW3_LIBRARY NAMES fftw3)
find_path(FFTW3_INCLUDE_DIR NAMES fftw3.h)
if(NOT FFTW3_LIBRARY OR NOT FFTW3_INCLUDE_DIR)
    message(FATAL_ERROR "FFTW3 (libfftw3 and fftw3.h) not found; on Debian/Ubuntu install libfftw3-dev")
endif()

pybind11_add_module(myFT ${SOURCES})

target_include_directories(myFT PRIVATE ${FFTW3_INCLUDE_DIR})
target_link_libraries(myFT PRIVATE pybind11::module ${FFTW3_LIBRARY})
4. setup.py
import os
import re
import sys
import platform
import subprocess

from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion

class CMakeExtension(Extension):
    """Extension placeholder whose compilation is delegated to CMake.

    ``sources`` is deliberately empty so setuptools compiles nothing itself;
    ``sourcedir`` is stored as an absolute path and later passed to ``cmake``
    as the source directory (the one holding CMakeLists.txt).
    """

    def __init__(self, name, sourcedir=''):
        # No sources: CMakeBuild, not setuptools, builds this extension.
        Extension.__init__(self, name, sources=[])
        # Resolve now so the path stays valid when cmake runs from build_temp.
        source_root = os.path.abspath(sourcedir)
        self.sourcedir = source_root


class CMakeBuild(build_ext):
    """``build_ext`` replacement that configures and builds extensions with CMake.

    Each extension is expected to carry a ``sourcedir`` (see ``CMakeExtension``),
    which is passed to ``cmake`` as the source directory; the compiled module is
    written to the directory setuptools expects for that extension.
    """

    def run(self):
        """Verify that CMake can be executed, then build every extension.

        Raises:
            RuntimeError: if ``cmake --version`` cannot be run, or on Windows if
                the reported version is older than 3.1.
        """
        try:
            out = subprocess.check_output(['cmake', '--version'])
        except OSError as err:
            raise RuntimeError("CMake must be installed to build the following extensions: " +
                               ", ".join(e.name for e in self.extensions)) from err

        if platform.system() == "Windows":
            match = re.search(r'version\s*([\d.]+)', out.decode())
            # Compare as a tuple of ints rather than with distutils' LooseVersion,
            # which is deprecated (distutils left the stdlib in Python 3.12).
            # If the version cannot be parsed, skip this check; the project's
            # cmake_minimum_required() still enforces a minimum at configure time.
            if match is not None:
                cmake_version = tuple(int(part) for part in match.group(1).split('.') if part)
                if cmake_version < (3, 1):
                    raise RuntimeError("CMake >= 3.1.0 is required on Windows")

        for ext in self.extensions:
            self.build_extension(ext)

    def build_extension(self, ext):
        """Configure and build one extension with CMake inside ``self.build_temp``.

        Raises:
            subprocess.CalledProcessError: if the configure or build step fails.
        """
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
        cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
                      '-DPYTHON_EXECUTABLE=' + sys.executable]

        cfg = 'Debug' if self.debug else 'Release'
        build_args = ['--config', cfg]

        if platform.system() == "Windows":
            # Also set the per-configuration output directory so the module
            # lands directly in extdir for the selected configuration.
            cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
            # Build a 64-bit module when running under a 64-bit interpreter.
            if sys.maxsize > 2**32:
                cmake_args += ['-A', 'x64']
            build_args += ['--', '/m']
        else:
            cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
            build_args += ['--', '-j2']

        env = os.environ.copy()
        env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''),
                                                              self.distribution.get_version())
        os.makedirs(self.build_temp, exist_ok=True)
        subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env)
        subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
        print()

# Distribution metadata. The single extension is compiled by CMake through
# CMakeBuild, which replaces the default build_ext command.
setup(
    name="myFT",
    version="0.0.1",
    author="",
    author_email="",
    description="",
    long_description="",
    ext_modules=[CMakeExtension("myFT")],
    cmdclass={"build_ext": CMakeBuild},
    zip_safe=False,
)
...