Я понимаю ваше замешательство. Посмотрите на приведенный ниже пример и комментарии:
# [Batch size, Sequence length, Embedding size]
inputs = torch.rand(128, 5, 300)
gru = nn.GRU(input_size=300, hidden_size=400, num_layers=2, batch_first=True)
with torch.no_grad():
# output is all hidden states, for each element in the batch of the last layer in the RNN
# a is the last hidden state of the first layer
# b is the last hidden state of the second (last) layer
output, (a, b) = gru(inputs)
Если мы распечатаем фигуры, они подтвердят наше понимание:
print(output.shape) # torch.Size([128, 5, 400])
print(a.shape) # torch.Size([128, 400])
print(b.shape) # torch.Size([128, 400])
Кроме того, мы можем проверить, является ли последний скрытымсостояние для каждого элемента в пакете последнего слоя, полученного из output
, равно b
:
np.testing.assert_almost_equal(b.numpy(), output[:,:-1,:].numpy())
Наконец, мы можем создать RNN с 3 слоями и запустить тот жеtests:
gru = nn.GRU(input_size=300, hidden_size=400, num_layers=3, batch_first=True)
with torch.no_grad():
output, (a, b, c) = gru(inputs)
np.testing.assert_almost_equal(c.numpy(), output[:,-1,:].numpy())
Опять утверждение проходит, но только если мы сделаем это для c
, который теперь является последним уровнем RNN. В противном случае:
np.testing.assert_almost_equal(b.numpy(), output[:,-1,:].numpy())
Выдает ошибку:
AssertionError: Массивы почти не равны 7 десятичным знакам
Я надеюсь, что это проясняет ситуациюдля вас.