Как передать вход в слой Pytorch LSTM - PullRequest
0 голосов
/ 21 мая 2018

Я новичок в PyTorch и не понимаю, как приспособить сеть, используя эту платформу.

У меня есть простая модель в Keras:

model = Sequential()
model.add(Embedding(vocab_size, embedding_size, input_length=55, weights=[pretrained_weights]))
model.add(Bidirectional(LSTM(units=len(X_train))))
model.add(Dense(n_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
          optimizer = RMSprop(lr=0.0005),
          metrics=['accuracy'])

model.fit(np.array(X_train), np.array(y_train), epochs=100, validation_data=(np.array(X_val), np.array(y_val)))

В Keras я могупросто напишите в сети мой X_train (2D-массив, содержащий индексы) и мой y_train (2D-массив, содержащий один индекс для каждого входа).

Теперь, чтобы прокормить свою модель PyTorch, я преобразовал свою матрицу втакой тензор:

M = torch.tensor(X_train)

И определил мою сеть:

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super(BiRNN, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
    self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection

  def forward(self, x):
    # Set initial states
    h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection 
    c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)

    # Forward propagate LSTM
    out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)

    # Decode the hidden state of the last time step
    out = self.fc(out[:, -1, :])
    return out

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Но я не понимаю, как вызвать функцию и использовать мои собственные данные для предсказаний.

...