У меня есть этот код из учебника PyTorch на seq2seq с вниманием. С https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
def forward(self, input, hidden, encoder_outputs):
"""
Run LSTM through 1 time step
SHAPE REQUIREMENT
- input: <1 x batch_size x N_LETTER>
- hidden: (<num_layer x batch_size x hidden_size>, <num_layer x batch_size x hidden_size>)
- lstm_out: <1 x batch_size x N_LETTER>
"""
# Incorporate attention to LSTM input
hidden_cat = torch.cat((hidden[0], hidden[1]), dim=2)
attn_weights = F.softmax(self.attn(torch.cat((input, hidden_cat), 2)), dim=2)
attn_applied = torch.bmm(attn_weights.transpose(0,1),encoder_outputs.transpose(0,1)).transpose(0,1)
attn_output = torch.cat((input, attn_applied), 2)
attn_output = F.relu(self.attn_combine(attn_output))
# Run LSTM
lstm_out, hidden = self.lstm(attn_output, hidden)
lstm_out = self.fc1(lstm_out)
lstm_out = self.softmax(lstm_out)
return lstm_out, hidden
В настоящее время я пытаюсь замаскировать тензор attn_weights, который равен 1 x размеру пакета x максимальной длине имени. Несмотря на то, что учебник представляет собой один seq2seq с предложениями, я делаю seq2seq на уровне символов имен, поэтому «AB» будет встроен как
[[1,0,...],
[0,1,...]]
Предположим, что A = 0 и B = 1
Таким образом, входные данные - это тензор, который раньше был именами, а attn_weights кажется весами для каждого индекса символа. Мне нужно найти способ установить эти веса в attn_weights на отрицательную бесконечность, когда индекс соотносится с символом pad. Так, например, имя «Diddy» и, скажем, максимальная длина имени равно 6, тогда оно станет «Diddy», просто представьте, что это один символ. Таким образом, основная проблема заключается в том, что все имена имеют разное количество отступов. Есть ли элегантный способ добавить маску к отступу в attn_weights? Единственный способ, о котором я могу подумать, - передать список длины каждого имени, а затем для каждого пакета в attn_weights просто изменить все индексы> длина имени на отрицательную бесконечность, но я чувствую, что это довольно дерьмово.