Вы можете удалить петли for
, выполнив это (это должно ускориться за счет памяти, если M
и N
не малы):
diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines
Конечно,вам не нужно делать это шаг за шагом. Я просто попытался прояснить имена переменных.
Примечание : если вы не предоставите dim
для torch.cross
, он будет использовать первый dim=3
, который дастневерные результаты, если N=3
(из документов ):
Если dim не указан, по умолчанию используется первое найденное измерение с размером 3.
Если вам интересно, вы можете проверить здесь , почему я выбрал expand
вместо repeat
.