У меня есть тензоры X формы BxNxD
и Y формы BxNxD
.
Я хочу вычислить попарные расстояния для каждого элемента в пакете, т.е. я тензор BxMxN
.
Как мне это сделать?
Здесь обсуждается эта тема: https://github.com/pytorch/pytorch/issues/9406,, но я не понимаю этого, так как есть много деталей реализации, в то время как фактическое решение не выделено..
Наивным подходом было бы использовать ответ для парных расстояний без группировки, как описано здесь: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065, то есть
import torch
import numpy as np
B = 32
N = 128
M = 256
D = 3
X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))
def pairwise_distances(x, y=None):
x_norm = (x**2).sum(1).view(-1, 1)
if y is not None:
y_t = torch.transpose(y, 0, 1)
y_norm = (y**2).sum(1).view(1, -1)
else:
y_t = torch.transpose(x, 0, 1)
y_norm = x_norm.view(1, -1)
dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
return torch.clamp(dist, 0.0, np.inf)
out = []
for b in range(B):
out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)
Как я могу сделать это без зацикливанияB?Спасибо