Я пытаюсь сделать исчерпывающую связь между тензорами. Так, например,
У меня есть тензор:
a = torch.randn(3, 512)
Я хочу объединить как
конкат (t1, t1), конкат (t1, t2), конкат (t1, t3), конкат (t2, t1), конкат (t2, t2) ....
В качестве наивного решения,
Я использовал for
цикл:
ans = []
result = []
split = torch.split(a, [1, 1, 1], dim=0)
for i in range(len(split)):
ans.append(split[i])
for t1 in ans:
for t2 in ans:
result.append(torch.cat((t1,t2), dim=1))
Проблема в том, что каждая эпоха занимает очень много времени, а код медленный.
Я попробовал решение, опубликованное в вопросе PyTorch: Как реализовать внимание для слоя внимания графика , но это дает ошибку памяти.
t1 = a.repeat(1, a.shape[0]).view(a.shape[0] * a.shape[0], -1)
t2 = a.repeat(a.shape[0], 1)
result.append(torch.cat((t1, t2), dim=1))
Я уверен, что есть более быстрый путь, но я не смог его выяснить.