Понимание формы ввода в PyTorch LSTM - PullRequest
1 голос
/ 06 мая 2020

Кажется, это один из самых распространенных вопросов о LSTM в PyTorch, но я все еще не могу понять, какой должна быть форма ввода для PyTorch LSTM.

Даже после нескольких сообщений ( 1 , 2 , 3 ) и пробуя решения, похоже, что это не работает.

Справочная информация. Я закодировал текстовые последовательности (переменной длины) в пакете размером 12, и эти последовательности дополняются и упаковываются с использованием функциональности pad_packed_sequence. MAX_LEN для каждой последовательности - 384, и каждый токен (или слово) в последовательности имеет размерность 768. Следовательно, мой пакетный тензор может иметь одну из следующих форм: [12, 384, 768] или [384, 12, 768].

Пакет будет моим входом в модуль PyTorch rnn (здесь lstm).

Согласно документации PyTorch для LSTM , его входные размеры составляют (seq_len, batch, input_size), что я понимаю следующим образом.
seq_len - количество временных шагов в каждом входном потоке (длина вектора признаков).
batch - размер каждого пакета входных последовательностей.
input_size - размерность каждого входного токена или шаг по времени.

lstm = nn.LSTM(input_size=?, hidden_size=?, batch_first=True)

Какие здесь должны быть точные значения input_size и hidden_size?

1 Ответ

2 голосов
/ 07 мая 2020

Вы объяснили структуру своего ввода, но вы не установили связь между вашими входными размерами и ожидаемыми входными размерами LSTM.

Давайте разберем ваш вход (присвоение имен измерениям):

  • batch_size: 12
  • seq_len: 384
  • input_size / num_features: 768

Это означает, что input_size LSTM должно быть 768.

hidden_size не зависит от вашего ввода, а скорее от того, сколько функций должен создать LSTM, который затем также используется для скрытого состояния в качестве вывода, поскольку это последнее скрытое состояние. Вы должны решить, сколько функций вы хотите использовать для LSTM.

Наконец, для формы ввода, настройка batch_first=True требует, чтобы ввод имел форму [batch_size, seq_len, input_size], в вашем случае это будет [12, 384, 768].

import torch
import torch.nn as nn

# Size: [batch_size, seq_len, input_size]
input = torch.randn(12, 384, 768)

lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

output, _ = lstm(input)
output.size()  # => torch.Size([12, 384, 512])
...