Как вычислить попарное расстояние между набором точек и линиями в PyTorch? - PullRequest
1 голос
/ 01 ноября 2019

Набор точек A представляет собой матрицу Nx3, и из двух наборов точек B и C с одинаковым размером Mx3 мы могли бы получить линии BC между ними. Теперь я хочу вычислить расстояние от каждой точки в A до каждой линии в BC. B - это Mx3, а C - это Mx3, тогда линии идут от точек с соответствующими рядами, поэтому BC - это матрица Mx3. Основной метод вычисляется следующим образом:

D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
    p = A[i]  # 1x3
    for j in range(M):
        p1 = B[j] # 1x3
        p2 = C[j] # 1x3
        D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2) 

Есть ли более быстрый метод для выполнения этой работы? Благодаря.

1 Ответ

3 голосов
/ 01 ноября 2019

Вы можете удалить петли for, выполнив это (это должно ускориться за счет памяти, если M и N не малы):

diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines

Конечно,вам не нужно делать это шаг за шагом. Я просто попытался прояснить имена переменных.

Примечание : если вы не предоставите dim для torch.cross, он будет использовать первый dim=3, который дастневерные результаты, если N=3 (из документов ):

Если dim не указан, по умолчанию используется первое найденное измерение с размером 3.

Если вам интересно, вы можете проверить здесь , почему я выбрал expand вместо repeat.

...