Как я могу эффективно изменить / сделать попарно матрицу расстояний? - PullRequest
1 голос
/ 03 февраля 2020
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y = x
        y_norm = x_norm.view(1, -1)
    dist = (x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)))
    return dist

Выше приведен код, используемый для вычисления матрицы парных расстояний (M * N) между x (M точек) и y (N точек).

Я надеюсь создать матрицу парных расстояний, которая имеет 0 элемент, когда расстояние между двумя точками больше указанного c значения 'T'.

В таком случае, что мне делать?

Спасибо

1 Ответ

3 голосов
/ 03 февраля 2020

Я думаю, что вы ищете torch.where:

new_dist = troch.where(dist > T, dist, 0.)
...