Ошибка сегментации Cython при создании объекта в стеке - PullRequest
0 голосов
/ 10 апреля 2020

У меня есть следующий класс C ++, который я хочу обернуть с помощью Cython. Класс содержит больше методов, но я включил только важные.

#include <cstdint>
#include <cassert>
#include <vector>

template<class T>
class HkCluster {
private:
    std::vector<T> matrix;
    std::vector<T> labels;
    int n_labels{};
    int cols{};
    int rows{};
public:
    HkCluster() = default;

    HkCluster(T *matrix, int cols, int rows) {
        this->cols = cols;
        this->rows = rows;
        this->n_labels = cols * rows / 2;
        this->labels = std::vector<T>(cols * rows, 0);
        this->matrix = std::vector<T>(matrix, matrix + cols * rows);
    }

    void setMatrix(T * matrix){
        this->matrix = std::vector<T>(matrix, matrix + this->cols * this->rows);
    }

    void setCols(int cols) {
        HkCluster::cols = cols;
    }

    void setRows(int rows) {
        HkCluster::rows = rows;
    }
};

Вот как я обертываю класс с помощью Cython. Этот код работает, но вы можете видеть, что я создал объект в куче, которая не является оптимальной

import numpy as np

cdef extern from "hk.cpp":
    cdef cppclass HkCluster[T]:
        HkCluster();
        HkCluster(T *matrix, int cols, int rows);
        T hk_cluster(T * ret);
        void setMatrix(T * matrix);
        void setCols(int cols);
        void setRows(int rows);

def hk(a not None):
    arr = a.copy()
    if not arr.flags['C_CONTIGUOUS']:
        arr = np.ascontiguousarray(arr)
    if arr.dtype != np.int32:
        arr = np.cast[np.int32](arr)
    cdef int[:, ::1] a_mem_view = arr
    cdef HkCluster[int] *cluster = new HkCluster[int](&(a_mem_view[0, 0]), a.shape[1], a.shape[0])
    cluster.setMatrix(&(a_mem_view[0, 0]))
    cluster.setRows(a.shape[0])
    cluster.setCols(a.shape[1])
    count = cluster.hk_cluster(&(a_mem_view[0, 0]))
    del cluster
    return count, arr

Это моя попытка создать объект в стеке, что приводит к ошибке сегментации. Я прочитал документацию Cython от https://cython.readthedocs.io/en/latest/src/userguide/wrapping_CPlusPlus.html и код должен работать. Что я делаю не так?

import numpy as np

cdef extern from "hk.cpp":
    cdef cppclass HkCluster[T]:
        HkCluster();
        HkCluster(T *matrix, int cols, int rows);
        T hk_cluster(T * ret);
        void setMatrix(T * matrix);
        void setCols(int cols);
        void setRows(int rows);

def hk(a not None):
    arr = a.copy()
    if not arr.flags['C_CONTIGUOUS']:
        arr = np.ascontiguousarray(arr)
    if arr.dtype != np.int32:
        arr = np.cast[np.int32](arr)
    cdef int[:, ::1] a_mem_view = arr
    cdef HkCluster[int] cluster
    cluster.setMatrix(&(a_mem_view[0, 0]))
    cluster.setRows(a.shape[0])
    cluster.setCols(a.shape[1])
    count = cluster.hk_cluster(&(a_mem_view[0, 0]))
    return count, arr
...