Как я могу эффективно определить окрестности из парной матрицы расстояний? - PullRequest
0 голосов
/ 31 января 2020

У меня есть матрица парных расстояний M * N между M точками из группы A и N точками из группы B.

Я хочу получить список соседних точек из группы B для каждой точки из группы A.

Есть ли эффективный код для этой проблемы с использованием pytorch? вместо нескольких 'за' л oop.

Спасибо

1 Ответ

0 голосов
/ 31 января 2020

Вы можете использовать sort:

import torch

# fake pairwise distance matrix, M=3, N=4
x = torch.rand((3,4))
print(x)
# tensor([[0.7667, 0.6847, 0.3779, 0.3007],
#         [0.9881, 0.9909, 0.3180, 0.5389],
#         [0.6341, 0.8095, 0.4214, 0.7216]])

closest = torch.sort(x, dim=-1)  # default is -1, but I prefer to be clear

# let's say you want the k=2 closest points
k=2
closest_k_values = closest[0][:, :k]
closest_k_indices = closest[1][:, :k]

print(closest_k_values)
# tensor([[0.3007, 0.3779],
#         [0.3180, 0.5389],
#         [0.4214, 0.6341]])

print(closest_k_indices)
# tensor([[3, 2],
#         [2, 3],
#         [2, 0]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...