Векторизация ближайшего соседа - PullRequest
0 голосов
/ 25 июня 2018

Я ищу способ улучшить производительность моей простой функции ближайшего соседа, но я не очень хорошо разбираюсь в векторизации с numpy. Любая помощь будет оценена!

def knn_search(pts_a, pts_b, k):
    """
    Finds the k nearest neighbours of each point in pts_a in pts_b
    :param pts_a:
    :param pts_b:
    :param k:
    :return dist, idx:
    """

    dist = np.empty((pts_b.shape[0], pts_a.shape[0]))
    for i in range(pts_b.shape[0]):
        dist[i, :] = np.linalg.norm(pts_a - pts_b[i, :], axis=1)

    idx = np.argsort(dist, axis=1)
    dist = np.sort(dist, axis=1)

    return dist[:, :k], idx[:, :k]


a = np.random.rand(10, 2)
b = np.random.rand(10, 2)

distance, indices = knn_search(a, b, 5)

1 Ответ

0 голосов
/ 25 июня 2018

Вы можете заменить вашу петлю внешней разницей, используя трансляцию:

def knn_search(pts_a, pts_b, k):
    """
    Finds the k nearest neighbours of each point in pts_a in pts_b
    :param pts_a:
    :param pts_b:
    :param k:
    :return dist, idx:
    """

    dist = np.linalg.norm(pts_a - pts_b[:, None], axis=-1)
    idx = np.argsort(dist, axis=1)
    dist = np.sort(dist, axis=1)

    return dist[:, :k], idx[:, :k]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...