Я пытался понять, как работает механизм внимания. В настоящее время рассматривается пример tf js -examples / date-Conversion-Внимание . Я обнаружил , что в примере используется dot product alignment score
(из Effective Approaches to Attention-based Neural Machine Translation
).
Итак, это выражение представляется как
let attention = tf.layers.dot({axes: [2, 2]}).apply([decoder, encoder]);
в коде . Декодер (h_t
) имеет форму [10,64], а кодер (h_s
) - [12,64], поэтому результат будет иметь форму [10,12]. Пока все хорошо.
Теперь я пытаюсь реализовать concat alignment score
, который выглядит следующим образом .
Итак, первое, что нужно сделать, это объединить h_t
и h_s
. Тем не менее, они имеют разные формы, поэтому я не знаю, как поступить. Должен ли я как-то изменить тензоры? Если так, то какой была бы форма?
Я гуглил, чтобы узнать, как другие люди делают это, и нашел это .
#For concat scoring, decoder hidden state and encoder outputs are concatenated first
out = torch.tanh(self.fc(decoder_hidden+encoder_outputs))
Но это не кажется правильным, так как они суммируют значения вместо объединения.
Любое руководство будет оценено.
ОБНОВЛЕНИЕ Вот краткое описание модели:
__________________________________________________________________________________________________
Layer (type) Output shape Param # Receives inputs
==================================================================================================
input1 (InputLayer) [null,12] 0
__________________________________________________________________________________________________
embedding_Embedding1 (Embedding [null,12,64] 2240 input1[0][0]
__________________________________________________________________________________________________
input2 (InputLayer) [null,10] 0
__________________________________________________________________________________________________
lstm_LSTM1 (LSTM) [null,12,64] 33024 embedding_Embedding1[0][0]
__________________________________________________________________________________________________
embedding_Embedding2 (Embedding [null,10,64] 832 input2[0][0]
__________________________________________________________________________________________________
encoderLast (GetLastTimestepLay [null,64] 0 lstm_LSTM1[0][0]
__________________________________________________________________________________________________
lstm_LSTM2 (LSTM) [null,10,64] 33024 embedding_Embedding2[0][0]
encoderLast[0][0]
encoderLast[0][0]
__________________________________________________________________________________________________
dot_Dot1 (Dot) [null,10,12] 0 lstm_LSTM2[0][0]
lstm_LSTM1[0][0]
__________________________________________________________________________________________________
attention (Activation) [null,10,12] 0 dot_Dot1[0][0]
__________________________________________________________________________________________________
context (Dot) [null,10,64] 0 attention[0][0]
lstm_LSTM1[0][0]
__________________________________________________________________________________________________
concatenate_Concatenate1 (Conca [null,10,128] 0 context[0][0]
lstm_LSTM2[0][0]
__________________________________________________________________________________________________
time_distributed_TimeDistribute [null,10,64] 8256 concatenate_Concatenate1[0][0]
__________________________________________________________________________________________________
time_distributed_TimeDistribute [null,10,13] 845 time_distributed_TimeDistributed1
==================================================================================================
Total params: 78221
Trainable params: 78221
Non-trainable params: 0
__________________________________________________________________________________________________