Можно использовать `torch.sort` (а если нужны только k ближайших точек — `torch.topk` с `largest=False`):
import torch

# Fake pairwise distance matrix: M=3 rows, N=4 columns.
# Fixed values (instead of torch.rand) so the printed output below is reproducible.
x = torch.tensor([[0.7667, 0.6847, 0.3779, 0.3007],
                  [0.9881, 0.9909, 0.3180, 0.5389],
                  [0.6341, 0.8095, 0.4214, 0.7216]])
print(x)
# tensor([[0.7667, 0.6847, 0.3779, 0.3007],
#         [0.9881, 0.9909, 0.3180, 0.5389],
#         [0.6341, 0.8095, 0.4214, 0.7216]])

# Let's say you want the k=2 closest points for each row.
k = 2
# torch.topk with largest=False picks the k smallest entries along the last
# dimension without sorting the whole row; sorted=True returns them in
# ascending order, the same result as torch.sort(x, dim=-1) followed by [:, :k].
# k is clamped to the row length because topk raises if k exceeds it
# (plain slicing would silently return fewer columns).
closest = torch.topk(x, min(k, x.size(-1)), dim=-1, largest=False, sorted=True)
closest_k_values = closest.values
closest_k_indices = closest.indices
print(closest_k_values)
# tensor([[0.3007, 0.3779],
#         [0.3180, 0.5389],
#         [0.4214, 0.6341]])
print(closest_k_indices)
# tensor([[3, 2],
#         [2, 3],
#         [2, 0]])