x_norm = (x**2).sum(1).view(-1, 1)
if y is not None:
y_norm = (y**2).sum(1).view(1, -1)
else:
y = x
y_norm = x_norm.view(1, -1)
dist = (x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)))
return dist
Выше приведен код, используемый для вычисления матрицы парных расстояний (M * N) между x (M точек) и y (N точек).
Я надеюсь создать матрицу парных расстояний, которая имеет 0 элемент, когда расстояние между двумя точками больше указанного c значения 'T'.
В таком случае, что мне делать?
Спасибо