В модели seq2seq кодер кодирует входные последовательности, заданные в виде мини-пакетов.Скажем, например, входное значение равно B x S x d
, где B - размер пакета, S - максимальная длина последовательности и d - размер вложения слова.Тогда выходной сигнал кодера равен B x S x h
, где h - размер скрытого состояния кодера (который представляет собой RNN).
Теперь при декодировании (во время обучения) входные последовательности задаются по одной за раз., поэтому ввод равен B x 1 x d
, а декодер выдает тензор формы B x 1 x h
.Теперь, чтобы вычислить вектор контекста, нам нужно сравнить скрытое состояние этого декодера с кодированными состояниями кодера.
Итак, рассмотрим, что у вас есть два тензора формы T1 = B x S x h
и T2 = B x 1 x h
.Так что, если вы можете выполнить пакетное матричное умножение следующим образом.
out = torch.bmm(T1, T2.transpose(1, 2))
По существу, вы умножаете тензор формы B x S x h
на тензор формы B x h x 1
, и это приведет к B x S x 1
, который являетсявес внимания для каждой партии.
Здесь вес внимания B x S x 1
представляет показатель сходства между текущим скрытым состоянием декодера и всеми скрытыми состояниями кодера.Теперь вы можете взять веса внимания для умножения на скрытое состояние кодировщика B x S x h
путем транспонирования в первую очередь, и это даст тензор формы B x h x 1
.И если вы выполните сжатие при dim = 2, вы получите тензор формы B x h
, который является вашим контекстным вектором.
Этот вектор контекста (B x h
) обычно объединяется со скрытым состоянием декодера (* 1028)*, нажмите dim = 1), чтобы предсказать следующий токен.