Как указано в scipy.spatial.KDTree()
:
При больших размерах (20 уже большие) не ожидайте, что это будет работать значительно быстрее, чем грубая сила.Высокоразмерные запросы ближайших соседей - существенная открытая проблема в информатике.
(эта заметка присутствует и в scipy.spatial.cKDTree()
, хотя, вероятно, это копирование-вставкаошибка документации).
Я позволил себе переписать ваш код с соответствующими функциями, чтобы я мог запустить некоторые автоматизированные тесты (на основе этого шаблона ).Я также включил реализацию Numba методом «грубой силы»:
import numpy as np
import scipy as sp
import numba as nb
import scipy.spatial
SCALE = 400.0
RADIUS = 50.0
def find_nn_np(points, radius=RADIUS, p=2):
n_points, n_dim = points.shape
result = np.empty(n_points, dtype=object)
for i in range(n_points):
result[i] = np.where(np.sum(np.abs(points - points[i:i + 1, :]) ** p, axis=1) < radius ** p)[0].tolist()
return result
def find_nn_kd_tree(points, radius=RADIUS):
tree = sp.spatial.KDTree(points)
return tree.query_ball_point(points, radius)
def find_nn_kd_tree_cy(points, radius=RADIUS):
tree = sp.spatial.cKDTree(points)
return tree.query_ball_point(points, radius)
@nb.jit
def neighbors_indexes_jit(radius, center, points, p=2):
n_points, n_dim = points.shape
k = 0
res_arr = np.empty(n_points, dtype=nb.int64)
for i in range(n_points):
dist = 0.0
for j in range(n_dim):
dist += abs(points[i, j] - center[j]) ** p
if dist < radius ** p:
res_arr[k] = i
k += 1
return res_arr[:k]
@nb.jit(forceobj=True, parallel=True)
def find_nn_jit(points, radius=RADIUS):
n_points, n_dim = points.shape
result = np.empty(n_points, dtype=object)
for i in nb.prange(n_points):
result[i] = neighbors_indexes_jit(radius, points[i], points, 2)
return result
Это тесты, которые я получил (я пропустил scipy.spatial.KDTree()
, потому что это было далеко от графика, в соответствии с вашими выводами):
(для полноты ниже приведен код, необходимый для адаптации шаблона)
def gen_input(n, dim=2, scale=SCALE):
return scale * np.random.rand(n, dim)
def equal_output(a, b):
return all(sorted(a_i) == sorted(b_i) for a_i, b_i in zip(a, b))
funcs = find_nn_np, find_nn_jit, find_nn_kd_tree_cy
input_sizes = tuple(int(2 ** (2 + (1 * i) / 4)) for i in range(32, 32 + 16 + 1))
print('Input Sizes:\n', input_sizes, '\n')
runtimes, input_sizes, labels, results = benchmark(
funcs, gen_input=gen_input, equal_output=equal_output,
input_sizes=input_sizes)
plot_benchmarks(runtimes, input_sizes, labels, units='s')