Cython: как ускорить рекурсивные функции? - PullRequest
1 голос
/ 20 июня 2020

Я реализую дерево сегментов в cython и сравниваю его с реализацией python.

Версия cython кажется только в 1,5 раза быстрее, и я хочу сделать ее еще быстрее.

Обе реализации можно считать правильными.

Вот код cython:

# distutils: language = c++
from libcpp.vector cimport vector

cdef struct Result:
    int range_sum  
    int range_min 
    int range_max



cdef class SegmentTree:
    cdef vector[int] nums
    cdef vector[Result] tree 

    def __init__(self, vector[int] nums):
        self.nums = nums
        self.tree.resize(4 * len(nums)) #just a safe upper bound 
        self._build(1, 0, len(nums)-1)

    cdef Result _build(self, int index, int left, int right):
        cdef Result result

        if left == right:
            value = self.nums[left]
            result.range_max, result.range_min, result.range_sum = value, value, value 
            self.tree[index] = result
            return self.tree[index]
        else:
            mid = (left+right)//2
            left_range_result = self._build(index*2, left, mid)
            right_range_result = self._build(index*2+1, mid+1, right)
            self.tree[index] = self.combine_range_results(left_range_result, right_range_result)
            return self.tree[index]

    cdef Result range_query(self, int query_i, int query_j):
        return self._range_query(query_i, query_j, 0, len(self.nums)-1, 1)

    cdef Result _range_query(self, int query_i, int query_j, int current_i, int current_j, int index):
        if current_i == query_i and current_j == query_j:
            return self.tree[index]
        else:
            mid = (current_i + current_j)//2 
            if query_j <= mid:
                return self._range_query(query_i, query_j, current_i, mid, index*2)
            elif mid < query_i:
                return self._range_query(query_i, query_j, mid+1, current_j, index*2+1 )  
            else:
                left_range_result = self._range_query(query_i, mid, current_i, mid, index*2)
                right_range_result = self._range_query(mid+1, query_j, mid+1, current_j, index*2+1)
                return self.combine_range_results(left_range_result, right_range_result)


    cpdef int range_sum(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_sum 
    cpdef int range_min(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_min
    cpdef int range_max(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_max

    cpdef void  update(self, int i, int new_value):
        self._update(i, new_value, 1, 0, len(self.nums)-1)

    cdef Result _update(self, int i, int new_value, int index, int left, int right):
        if left == right == i:
            self.tree[index] = [new_value, new_value, new_value]
            return self.tree[index]
        if left == right:
            return self.tree[index]
        mid = (left+right)//2 
        left_range_result = self._update(i, new_value, index*2, left, mid)
        right_range_result = self._update(i, new_value, index*2+1, mid+1, right)
        self.tree[index] = self.combine_range_results(left_range_result, right_range_result)
        return self.tree[index]

    cdef Result combine_range_results(self, Result r1, Result r2):
        cdef Result result;
        result.range_min = min(r1.range_min, r2.range_min)
        result.range_max = max(r1.range_max, r2.range_max)
        result.range_sum = r1.range_sum + r2.range_sum
        return result 
        

Вот python версия:




class PurePythonSegmentTree:
    def __init__(self, nums):
        self.nums = nums
        self.tree = [0] * (len(nums) * 4)
        self._build(1, 0, len(nums) - 1)

    def _build(self, index, left, right):
        if left == right:
            value = self.nums[left]
            self.tree[index] = (value, value, value)
            return self.tree[index]
        else:
            mid = (left + right) // 2
            left_range_result = self._build(index * 2, left, mid)
            right_range_result = self._build(index * 2 + 1, mid + 1, right)
            self.tree[index] = self._combine_range_results(
                left_range_result, right_range_result)
            return self.tree[index]

    def range_query(self, query_i, query_j):
        return self._range_query(query_i, query_j, 0, len(self.nums) - 1, 1)

    def _range_query(self, query_i, query_j, current_i, current_j, index):
        if current_i == query_i and current_j == query_j:
            return self.tree[index]
        else:
            mid = (current_i + current_j) // 2
            if query_j <= mid:
                return self._range_query(query_i, query_j, current_i, mid,
                                         index * 2)
            elif mid < query_i:
                return self._range_query(query_i, query_j, mid + 1, current_j,
                                         index * 2 + 1)
            else:
                left_range_result = self._range_query(query_i, mid, current_i,
                                                      mid, index * 2)
                right_range_result = self._range_query(mid + 1, query_j,
                                                       mid + 1, current_j,
                                                       index * 2 + 1)
                return self._combine_range_results(left_range_result,
                                                   right_range_result)

    def range_sum(self, query_i, query_j):
        return self.range_query(query_i, query_j)[0]

    def range_min(self, query_i, query_j):
        return self.range_query(query_i, query_j)[1]

    def range_max(self, query_i, query_j):
        return self.range_query(query_i, query_j)[2]

    def _combine_range_results(self, r1, r2):
        return (r1[0] + r2[0], min(r1[1], r2[1]), max(r1[2], r2[2]))


Код тестирования :

import pytest
from segment_tree import SegmentTree

def _test_all_ranges(nums, correct_fn, test_fn, threshold=float("inf")):
    count = 0
    for i in range(len(nums)):
        for j in range(i + 1, len(nums)):
            if count > threshold:
                break
            expected = correct_fn(nums[i:j + 1])
            actual = test_fn(i, j)
            assert actual == expected
            count += 1


def test_cython_tree_speed(benchmark):
    nums = [i for i in range(1000)]

    @benchmark
    def foo():
        s = SegmentTree(nums)
        _test_all_ranges(nums, max, s.range_max, 20)


def test_python_tree_speed(benchmark):
    nums = [i for i in range(1000)]

    @benchmark
    def foo():
        s = PurePythonSegmentTree(nums)
        _test_all_ranges(nums, max, s.range_max, 20)

Статистика:

-------------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------------
Name (time in us)                 Min                   Max                  Mean              StdDev                Median                IQR            Outliers         OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_cython_tree_speed       708.0450 (1.0)      1,534.6150 (1.0)        739.7052 (1.0)       59.9436 (1.0)        717.7565 (1.0)      21.0070 (1.0)       116;200  1,351.8900 (1.0)        1290           1
test_python_tree_speed     1,625.1940 (2.30)     2,676.9020 (1.74)     1,696.8420 (2.29)     135.9121 (2.27)     1,644.7810 (2.29)     79.6613 (3.79)        36;37    589.3300 (0.44)        391           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Как сделать цитонизированную версию быстрее?

1 Ответ

2 голосов
/ 20 июня 2020

При попытке оптимизировать код Cython первым шагом является сборка с аннотациями (см., Например, эту часть Cython-документации ), т.е.

 cython -a xxx.pyx

или аналогичные. Он генерирует html, в котором можно увидеть, какие части кода используют Python -функции.

В вашем случае можно увидеть, что mid = (current_i + current_j)//2 является проблемой.

Он генерирует следующий код C:

  /*else*/ {
    __pyx_t_3 = __Pyx_PyInt_From_long(__Pyx_div_long((__pyx_v_current_i + __pyx_v_current_j), 2)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 42, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_3);
    __pyx_v_mid = __pyx_t_3;
    __pyx_t_3 = 0;

Т.е. mid является целым числом Python (из-за __Pyx_PyInt_From_long), и все операции с ним приведут к большему преобразованию в Python -целые и медленные операции.

Make mid cdef int. Изучите другие желтые линии (взаимодействие с Python) в аннотированном коде.

...