Я пытаюсь реализовать Bidirectional LSTM в PyTorch, и реализация метода forward выглядит следующим образом:
def forward(self, inputs):
word_embeddings = self.embedding.forward(inputs)
output = torch.zeros((inputs.shape[0], inputs.shape[1], self.tagset_size))
#init hidden states for LSTM
hidden_state_L = []
hidden_state_R = []
hidden_state_L.append(torch.zeros(inputs.shape[1], self.lstm_hidden_dim))
hidden_state_R.append(torch.zeros(inputs.shape[1], self.lstm_hidden_dim))
if next(self.parameters()).is_cuda:
output = output.cuda()
hidden_state_L[0] = hidden_state_L[0].cuda()
hidden_state_R[0] = hidden_state_R[0].cuda()
for i, word in enumerate(word_embeddings):
hidden_state_L.append(self.tanh.forward(self.fc_xh_L.forward(word) + self.fc_hh_L.forward(hidden_state_L[-1])))
for i, word in enumerate(reversed(word_embeddings)):
hidden_state_R.append(self.tanh.forward(self.fc_xh_R.forward(word) + self.fc_hh_R.forward(hidden_state_R[-1])))
hidden_state_L.pop()
hidden_state_R = hidden_state_R[1:][::-1]
for idx, (h_L, h_R) in enumerate(zip(hidden_state_L, hidden_state_R)):
output[idx - 1] = self.fc_hy.forward(torch.cat([h_L, h_R], dim = 1))
return output
В моей реализации скрытые состояния хранятся в списках, а затем из списка я получаю активации. Процесс обучения составляет полторы эпохи, достигая точности примерно 80%, а затем потери нейронной сети становятся NaN и обучение не выполняется. Где потенциально опасное место, или, может быть, можно предложить более разумный путь, чем эти списки? Или могут быть некоторые оговорки с вычислениями на устройстве?
Буду признателен за возможные предложения