PyTorch - применяя внимание эффективно - PullRequest
0 голосов
/ 10 декабря 2018

Я построил модель языка RNN с вниманием и создаю контекстный вектор для каждого элемента ввода, посещая все предыдущие скрытые состояния (только одно направление) .

Наиболее прямолинейное решение, на мой взгляд, заключается в использовании for-loop поверх вывода RNN, так что каждый вектор контекста вычисляется один за другим.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNN_LM(nn.Module):
    def __init__(self, hidden_size, vocab_size, embedding_dim=None, droprate=0.5):
        super().__init__()
        if not embedding_dim:
            embedding_dim = hidden_size
        self.embedding_matrix = nn.Embedding(vocab_size, embedding_dim)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=False)
        self.attn = nn.Linear(hidden_size, hidden_size)
        self.vocab_dist = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(droprate)

    def forward(self, x):
        x = self.dropout(self.embedding_matrix(x.view(-1, 1)))
        x, states = self.lstm(x)
        #print(x.size())
        x = x.squeeze()
        content_vectors = [x[0].view(1, -1)]
        # for-loop over hidden states and attention
        for i in range(1, x.size(0)):
            prev_states = x[:i]
            current_state = x[i].view(1, -1)

            attn_prod = torch.mm(self.attn(current_state), prev_states.t())
            attn_weights = F.softmax(attn_prod, dim=1)
            context = torch.mm(attn_weights, prev_states)
            content_vectors.append(context)

        return self.vocab_dist(self.dropout(torch.cat(content_vectors)))

Примечание: метод forward здесь используется только для обучения.

Однако это решение не очень эффективно, поскольку код плохо распараллеливается свычисление каждого вектора контекста последовательно.Но поскольку векторы контекста не зависят друг от друга, мне интересно, существует ли непоследовательный способ их вычисления.

Так есть способ вычисления векторов контекста без цикл for , чтобы можно было распараллелить больше вычислений?

1 Ответ

0 голосов
/ 11 декабря 2018

Хорошо, для ясности: я полагаю, что мы действительно заботимся о векторизации цикла for.Какая форма x?Предполагая, что x является двумерным, у меня есть следующий код, где v1 выполняет ваш цикл, а v2 - векторизованная версия:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(3, 6)

def v1():
    for i in range(1, x.size(0)):
        prev = x[:i]
        curr = x[i].view(1, -1)

        prod = torch.mm(curr, prev.t())
        attn = prod # same shape
        context = torch.mm(attn, prev)
        print(context)

def v2():
    # we're going to unroll the loop by vectorizing over the new,
    # 0-th dimension of `x`. We repeat it as many times as there
    # are iterations in the for loop
    repeated = x.unsqueeze(0).repeat(x.size(0), 1, 1)

    # we're looking to build a `prevs` tensor such that
    # prevs[i, x, y] == prev[x, y] at i-th iteration of the loop in v1,
    # up to 0-padding necessary to make them all the same size.
    # We need to build a higher-dimensional equivalent of torch.triu
    xs = torch.arange(x.size(0)).reshape(1, -1, 1)
    zs = torch.arange(x.size(0)).reshape(-1, 1, 1)
    prevs = torch.where(zs < xs, torch.tensor(0.), repeated)

    # this is an equivalent of the above iteration starting at 1
    prevs = prevs[:-1]
    currs = x[1:]

    # a batched matrix multiplication
    prod = torch.matmul(currs, prevs.transpose(1, 2))
    attn = prod # same shape
    context = torch.matmul(attn, prevs)
    # equivalent of a higher dimensional torch.diagonal
    contexts = torch.einsum('iij->ij', (context))
    print(contexts)

print(x)

print('\n------ v1 -------\n')
v1()
print('\n------ v2 -------\n')
v2()

, которая векторизует ваш цикл, с некоторыми оговорками.Во-первых, я предполагаю, что x является 2-мерным.Во-вторых, я пропускаю softmax, утверждая, что он не меняет размер входных данных и, следовательно, не влияет на векторизацию.Это правда, но, к сожалению, softmax 0-дополненного вектора v не равен 0-дополненному softmax unpadded v.Это можно исправить с помощью перенормировки.Пожалуйста, дайте мне знать, если мои предположения верны и является ли это хорошей отправной точкой для вашей работы.

...