Исчерпывающая конкатенация между тензорами - PullRequest
0 голосов
/ 14 января 2019

Я пытаюсь сделать исчерпывающую связь между тензорами. Так, например, У меня есть тензор:

a = torch.randn(3, 512)

Я хочу объединить как конкат (t1, t1), конкат (t1, t2), конкат (t1, t3), конкат (t2, t1), конкат (t2, t2) ....

В качестве наивного решения, Я использовал for цикл:

ans = []
result = []
split = torch.split(a, [1, 1, 1], dim=0)

for i in range(len(split)):
    ans.append(split[i])

for t1 in ans:
    for t2 in ans:
        result.append(torch.cat((t1,t2), dim=1))

Проблема в том, что каждая эпоха занимает очень много времени, а код медленный. Я попробовал решение, опубликованное в вопросе PyTorch: Как реализовать внимание для слоя внимания графика , но это дает ошибку памяти.

t1 = a.repeat(1, a.shape[0]).view(a.shape[0] * a.shape[0], -1)
t2 = a.repeat(a.shape[0], 1)
result.append(torch.cat((t1, t2), dim=1))

Я уверен, что есть более быстрый путь, но я не смог его выяснить.

...