Я реализую дерево сегментов в cython и сравниваю его с реализацией python.
Версия cython кажется только в 1,5 раза быстрее, и я хочу сделать ее еще быстрее.
Обе реализации можно считать правильными.
Вот код cython:
# distutils: language = c++
from libcpp.vector cimport vector
cdef struct Result:
int range_sum
int range_min
int range_max
cdef class SegmentTree:
cdef vector[int] nums
cdef vector[Result] tree
def __init__(self, vector[int] nums):
self.nums = nums
self.tree.resize(4 * len(nums)) #just a safe upper bound
self._build(1, 0, len(nums)-1)
cdef Result _build(self, int index, int left, int right):
cdef Result result
if left == right:
value = self.nums[left]
result.range_max, result.range_min, result.range_sum = value, value, value
self.tree[index] = result
return self.tree[index]
else:
mid = (left+right)//2
left_range_result = self._build(index*2, left, mid)
right_range_result = self._build(index*2+1, mid+1, right)
self.tree[index] = self.combine_range_results(left_range_result, right_range_result)
return self.tree[index]
cdef Result range_query(self, int query_i, int query_j):
return self._range_query(query_i, query_j, 0, len(self.nums)-1, 1)
cdef Result _range_query(self, int query_i, int query_j, int current_i, int current_j, int index):
if current_i == query_i and current_j == query_j:
return self.tree[index]
else:
mid = (current_i + current_j)//2
if query_j <= mid:
return self._range_query(query_i, query_j, current_i, mid, index*2)
elif mid < query_i:
return self._range_query(query_i, query_j, mid+1, current_j, index*2+1 )
else:
left_range_result = self._range_query(query_i, mid, current_i, mid, index*2)
right_range_result = self._range_query(mid+1, query_j, mid+1, current_j, index*2+1)
return self.combine_range_results(left_range_result, right_range_result)
cpdef int range_sum(self, int query_i, int query_j):
return self.range_query(query_i, query_j).range_sum
cpdef int range_min(self, int query_i, int query_j):
return self.range_query(query_i, query_j).range_min
cpdef int range_max(self, int query_i, int query_j):
return self.range_query(query_i, query_j).range_max
cpdef void update(self, int i, int new_value):
self._update(i, new_value, 1, 0, len(self.nums)-1)
cdef Result _update(self, int i, int new_value, int index, int left, int right):
if left == right == i:
self.tree[index] = [new_value, new_value, new_value]
return self.tree[index]
if left == right:
return self.tree[index]
mid = (left+right)//2
left_range_result = self._update(i, new_value, index*2, left, mid)
right_range_result = self._update(i, new_value, index*2+1, mid+1, right)
self.tree[index] = self.combine_range_results(left_range_result, right_range_result)
return self.tree[index]
cdef Result combine_range_results(self, Result r1, Result r2):
cdef Result result;
result.range_min = min(r1.range_min, r2.range_min)
result.range_max = max(r1.range_max, r2.range_max)
result.range_sum = r1.range_sum + r2.range_sum
return result
Вот python версия:
class PurePythonSegmentTree:
def __init__(self, nums):
self.nums = nums
self.tree = [0] * (len(nums) * 4)
self._build(1, 0, len(nums) - 1)
def _build(self, index, left, right):
if left == right:
value = self.nums[left]
self.tree[index] = (value, value, value)
return self.tree[index]
else:
mid = (left + right) // 2
left_range_result = self._build(index * 2, left, mid)
right_range_result = self._build(index * 2 + 1, mid + 1, right)
self.tree[index] = self._combine_range_results(
left_range_result, right_range_result)
return self.tree[index]
def range_query(self, query_i, query_j):
return self._range_query(query_i, query_j, 0, len(self.nums) - 1, 1)
def _range_query(self, query_i, query_j, current_i, current_j, index):
if current_i == query_i and current_j == query_j:
return self.tree[index]
else:
mid = (current_i + current_j) // 2
if query_j <= mid:
return self._range_query(query_i, query_j, current_i, mid,
index * 2)
elif mid < query_i:
return self._range_query(query_i, query_j, mid + 1, current_j,
index * 2 + 1)
else:
left_range_result = self._range_query(query_i, mid, current_i,
mid, index * 2)
right_range_result = self._range_query(mid + 1, query_j,
mid + 1, current_j,
index * 2 + 1)
return self._combine_range_results(left_range_result,
right_range_result)
def range_sum(self, query_i, query_j):
return self.range_query(query_i, query_j)[0]
def range_min(self, query_i, query_j):
return self.range_query(query_i, query_j)[1]
def range_max(self, query_i, query_j):
return self.range_query(query_i, query_j)[2]
def _combine_range_results(self, r1, r2):
return (r1[0] + r2[0], min(r1[1], r2[1]), max(r1[2], r2[2]))
Код тестирования :
import pytest
from segment_tree import SegmentTree
def _test_all_ranges(nums, correct_fn, test_fn, threshold=float("inf")):
count = 0
for i in range(len(nums)):
for j in range(i + 1, len(nums)):
if count > threshold:
break
expected = correct_fn(nums[i:j + 1])
actual = test_fn(i, j)
assert actual == expected
count += 1
def test_cython_tree_speed(benchmark):
nums = [i for i in range(1000)]
@benchmark
def foo():
s = SegmentTree(nums)
_test_all_ranges(nums, max, s.range_max, 20)
def test_python_tree_speed(benchmark):
nums = [i for i in range(1000)]
@benchmark
def foo():
s = PurePythonSegmentTree(nums)
_test_all_ranges(nums, max, s.range_max, 20)
Статистика:
-------------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_cython_tree_speed 708.0450 (1.0) 1,534.6150 (1.0) 739.7052 (1.0) 59.9436 (1.0) 717.7565 (1.0) 21.0070 (1.0) 116;200 1,351.8900 (1.0) 1290 1
test_python_tree_speed 1,625.1940 (2.30) 2,676.9020 (1.74) 1,696.8420 (2.29) 135.9121 (2.27) 1,644.7810 (2.29) 79.6613 (3.79) 36;37 589.3300 (0.44) 391 1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Как сделать цитонизированную версию быстрее?