Они не действительно одинаковы.Учтите, что у нас есть следующая однонаправленная модель GRU:
import torch.nn as nn
import torch
gru = nn.GRU(input_size = 8, hidden_size = 50, num_layers = 3, batch_first = True)
Пожалуйста, убедитесь, что вы внимательно наблюдаете форму ввода.
inp = torch.randn(1024, 112, 8)
out, hn = gru(inp)
Определенно,
torch.equal(out, hn)
False
Один из наиболее эффективных способов, который помог мне понять выходные данные по сравнению со скрытыми состояниями, заключался в том, чтобы рассматривать hn
как hn.view(num_layers, num_directions, batch, hidden_size)
, где num_directions = 2
для двунаправленных рекуррентных сетей (и еще 1,т.е. наш случай) .Таким образом,
hn_conceptual_view = hn.view(3, 1, 1024, 50)
Как указано в документе (обратите внимание на курсив и жирный шрифт) :
h_n формы (num_layers * num_directions, batch, hidden_size): тензор, содержащий скрытое состояние для t = seq_len (то есть для последнего временного шага)
В нашем случае это содержит скрытый вектор для временного шага t = 112
, гдевывод:
формы (seq_len, batch, num_directions * hidden_size): тензор, содержащий выходные элементы h_t из последнего слоя GRU, для каждого t . Если в качестве входных данных указан torch.nn.utils.rnn.PackedSequence, выходные данные также будут упакованными.Для распакованного случая направления могут быть разделены с помощью output.view (seq_len, batch, num_directions, hidden_size), причем forward и backward - направления 0 и 1. соответственно.
Таким образом, следовательно, можноdo:
torch.equal(out[:, -1], hn_conceptual_view[-1, 0, :, :])
True
Объяснение : я сравниваю последнюю последовательность из всех партий в out[:, -1]
со скрытыми векторами последнего слоя из hn[-1, 0, :, :]
Для Двунаправленный GRU (сначала требуется чтение однонаправленного):
gru = nn.GRU(input_size = 8, hidden_size = 50, num_layers = 3, batch_first = True bidirectional = True)
inp = torch.randn(1024, 112, 8)
out, hn = gru(inp)
Вид изменяется на (поскольку у нас есть два направления):
hn_conceptual_view = hn.view(3, 2, 1024, 50)
Если вы попытаетесьточный код:
torch.equal(out[:, -1], hn_conceptual_view[-1, 0, :, :])
False
Объяснение : Это потому, что мы даже сравниваем неправильные формы;
out[:, 0].shape
torch.Size([1024, 100])
hn_conceptual_view[-1, 0, :, :].shape
torch.Size([1024, 50])
Помните, что для двунаправленных сетей скрытые состояния объединяютсяна каждом временном шаге, где первый размер hidden_state
(т. е. out[:, 0,
:50
]
) - это скрытые состояния для прямой сети, а другой размер hidden_state
- дляв обратном направлении (т. е. out[:, 0,
50:
]
).Правильное сравнение для прямой сети будет таким:
torch.equal(out[:, -1, :50], hn_conceptual_view[-1, 0, :, :])
True
Если вы хотите скрытые состояния для обратной сети и с обратной сетиобрабатывает последовательность с временным шагом n ... 1
.Вы сравниваете первый временной шаг последовательности , но последний hidden_state
размер и меняете направление hn_conceptual_view
на 1
:
torch.equal(out[:, -1, :50], hn_conceptual_view[-1, 1, :, :])
True
В двух словах, как правило,Говоря:
Однонаправленный :
rnn_module = nn.RECURRENT_MODULE(num_layers = X, hidden_state = H, batch_first = True)
inp = torch.rand(B, S, E)
output, hn = rnn_module(inp)
hn_conceptual_view = hn.view(X, 1, B, H)
Где RECURRENT_MODULE
- это ГРУ или LSTM (на момент написания этого поста), B
- партияразмер, S
длина последовательности и E
размер вложения.
torch.equal(output[:, S, :], hn_conceptual_view[-1, 0, :, :])
True
Снова мы использовали S
, поскольку rnn_module
является прямым (то есть однонаправленным), а последний временной шаг сохраняется надлина последовательности S
.
Двунаправленный :
rnn_module = nn.RECURRENT_MODULE(num_layers = X, hidden_state = H, batch_first = True, bidirectional = True)
inp = torch.rand(B, S, E)
output, hn = rnn_module(inp)
hn_conceptual_view = hn.view(X, 2, B, H)
Сравнение
torch.equal(output[:, S, :H], hn_conceptual_view[-1, 0, :, :])
True
Выше приведено прямое сравнение сети, мы использовали :H
потому что форвард хранит свой скрытый вектор в первых H
элементах для каждого временного шага.
Для обратной сети:
torch.equal(output[:, 0, H:], hn_conceptual_view[-1, 1, :, :])
True
Мы изменили направление в hn_conceptual_view
на 1
чтобы получить скрытые векторы для обратной сети.
Для всех примеров мы использовали hn_conceptual_view[-1, ...]
, потому что мыинтересует только последний слой.