Хорошо, для ясности: я полагаю, что мы действительно заботимся о векторизации цикла for
.Какая форма x
?Предполагая, что x
является двумерным, у меня есть следующий код, где v1
выполняет ваш цикл, а v2
- векторизованная версия:
import torch
import torch.nn.functional as F
torch.manual_seed(0)
x = torch.randn(3, 6)
def v1():
for i in range(1, x.size(0)):
prev = x[:i]
curr = x[i].view(1, -1)
prod = torch.mm(curr, prev.t())
attn = prod # same shape
context = torch.mm(attn, prev)
print(context)
def v2():
# we're going to unroll the loop by vectorizing over the new,
# 0-th dimension of `x`. We repeat it as many times as there
# are iterations in the for loop
repeated = x.unsqueeze(0).repeat(x.size(0), 1, 1)
# we're looking to build a `prevs` tensor such that
# prevs[i, x, y] == prev[x, y] at i-th iteration of the loop in v1,
# up to 0-padding necessary to make them all the same size.
# We need to build a higher-dimensional equivalent of torch.triu
xs = torch.arange(x.size(0)).reshape(1, -1, 1)
zs = torch.arange(x.size(0)).reshape(-1, 1, 1)
prevs = torch.where(zs < xs, torch.tensor(0.), repeated)
# this is an equivalent of the above iteration starting at 1
prevs = prevs[:-1]
currs = x[1:]
# a batched matrix multiplication
prod = torch.matmul(currs, prevs.transpose(1, 2))
attn = prod # same shape
context = torch.matmul(attn, prevs)
# equivalent of a higher dimensional torch.diagonal
contexts = torch.einsum('iij->ij', (context))
print(contexts)
print(x)
print('\n------ v1 -------\n')
v1()
print('\n------ v2 -------\n')
v2()
, которая векторизует ваш цикл, с некоторыми оговорками.Во-первых, я предполагаю, что x
является 2-мерным.Во-вторых, я пропускаю softmax
, утверждая, что он не меняет размер входных данных и, следовательно, не влияет на векторизацию.Это правда, но, к сожалению, softmax 0-дополненного вектора v
не равен 0-дополненному softmax unpadded v
.Это можно исправить с помощью перенормировки.Пожалуйста, дайте мне знать, если мои предположения верны и является ли это хорошей отправной точкой для вашей работы.