Модель GRU в pytorch выводит два объекта: выходные функции, а также скрытые состояния. Я понимаю, что для классификации используются выходные функции, но я не совсем уверен, какие из них. В частности, в типичной архитектуре декодер-кодер, которая использует GRU в части декодера, обычно передается только последний (по времени, т. Е. T = N, где N - длина входной последовательности) вывод на кодер. . Какая часть выходного тензора относится к этому временному последнему выходу?
GRU создается таким образом (обратите внимание, что он двунаправленный):
self.gru = nn.GRU(
700,
700,
bidirectional=True,
batch_first=True,
)
Учитывая некоторый вектор внедрения, кусок текста размером 150x700, я использую GRU так (150 - длина последовательности, 700 - размер встраивания):
gru_out, gru_hidden = self.gru(embedding)
gru_out будет иметь форму 150x1400, где 150 - это снова длина последовательности и 1400 вдвое превышает размер встраивания, потому что GRU является двунаправленным (с точки зрения документации pytorch, hidden_size * num_directions).
Если я хочу получить доступ только к последнему выводу по времени, могу ли я нужно получить к нему доступ вот так?
tmp = gru_out.view(150, 2, 700)
last_out_first_direction = tmp[149, 0, :]
last_out_second_direction = tmp[149, 1, :]
Хотя технически это кажется правильным и похоже на ответ, опубликованный здесь , также потребуется, чтобы фактическая входная последовательность всегда имела длину 150 , тогда как обычно у вас также есть более короткие фактические входные последовательности, которые просто дополняются до длины 150. Однако в GRU обычно одна лежит в последнем фактическом входном токене, который, таким образом, также может находиться в позиции <150. Каков общий способ доступа к фактическому последнему токену или временному шагу (<= 150) вместо только технически последнего шага (всегда = 150)? </p>
Дополнительный вопрос: реверсируется ли вывод второго направления (поскольку направление, в котором информация проходит через ГРУ, также меняется на противоположное по сравнению с первым направлением), поэтому я должен фактически получить доступ к last_out_second_direction = tmp[0, 1, :]
вместо tmp[149, 1, :]
?