Реализация модели LSTM - PullRequest
0 голосов
/ 26 апреля 2020
class LSTM(nn.Module):
    def __init__(self, input_size=1, output_size=1, hidden_size=100, num_layers=16):
        super().__init__()
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)

        self.linear = nn.Linear(hidden_size, output_size)

        self.num_layers = num_layers

        self.hidden_cell = (torch.zeros(self.num_layers,12 ,self.hidden_size).to(device),
                            torch.zeros(self.num_layers,12 ,self.hidden_size).to(device))

    def forward(self, input_seq):
        #lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq) ,1, -1), self.hidden_cell)
        lstm_out, self.hidden_cell = self.lstm(input_seq, self.hidden_cell)        
        predictions = self.linear(lstm_out[:,-1,:])

        return predictions

Это моя модель LSTM, Input - это 4-мерный вектор. Размер партии равен 16, а отметка времени равна 12. Я хочу найти 13-й вектор с использованием вектора 12 последовательностей. Мой блок LSTM имеет выход [16,12,48]. Я не понял, почему я выбрал последний: out[:,-1,:]

1 Ответ

0 голосов
/ 26 апреля 2020

На первый взгляд, ваша проблема похожа на проблему классификации текста (т. Е. Последовательности), а output_size - это число классов, которым вы хотите назначить текст. Выбрав lstm_out[:,-1,:], вы фактически намереваетесь предсказать метку, связанную с вводимым текстом, только с использованием последнего скрытого состояния вашей сети LSTM, что полностью имеет смысл. Это то, что люди обычно делают для задач классификации текста. После этого ваш линейный слой будет выводить логиты для каждого класса, а затем вы можете использовать nn.Softmax() для получения вероятностей этих классов.

Последнее скрытое состояние сети LSTM - это распространение всех предыдущих скрытых состояний. LSTM, что означает, что он содержит агрегированную информацию о предыдущих состояниях ввода, которые он закодировал (давайте рассмотрим, что вы используете однонаправленный LSTM, как в вашем примере). Таким образом, для классификации входного текста вам нужно будет выполнить классификацию на основе общей информации обо всех токенах во входном тексте (который закодирован в последнем скрытом состоянии вашего LSTM). Вот почему вы передаете только последнее скрытое состояние линейному слою, который находится в вашей сети LSTM.

Примечание: Если вы намеревались выполнить маркировку последовательности (например, распознавание именованных сущностей), то вы бы использовали все выходные данные скрытого состояния из вашей сети LSTM. В таких задачах вам фактически потребуется информация об указанном c токене во входных данных.

...