pytorch: Как использовать вывод модели ГРУ? - PullRequest
0 голосов
/ 06 августа 2020

Модель GRU в pytorch выводит два объекта: выходные функции, а также скрытые состояния. Я понимаю, что для классификации используются выходные функции, но я не совсем уверен, какие из них. В частности, в типичной архитектуре декодер-кодер, которая использует GRU в части декодера, обычно передается только последний (по времени, т. Е. T = N, где N - длина входной последовательности) вывод на кодер. . Какая часть выходного тензора относится к этому временному последнему выходу?

GRU создается таким образом (обратите внимание, что он двунаправленный):

self.gru = nn.GRU(
            700,
            700,
            bidirectional=True,
            batch_first=True,
        )

Учитывая некоторый вектор внедрения, кусок текста размером 150x700, я использую GRU так (150 - длина последовательности, 700 - размер встраивания):

gru_out, gru_hidden = self.gru(embedding)

gru_out будет иметь форму 150x1400, где 150 - это снова длина последовательности и 1400 вдвое превышает размер встраивания, потому что GRU является двунаправленным (с точки зрения документации pytorch, hidden_size * num_directions).

Если я хочу получить доступ только к последнему выводу по времени, могу ли я нужно получить к нему доступ вот так?

tmp = gru_out.view(150, 2, 700)
last_out_first_direction = tmp[149, 0, :]
last_out_second_direction = tmp[149, 1, :]

Хотя технически это кажется правильным и похоже на ответ, опубликованный здесь , также потребуется, чтобы фактическая входная последовательность всегда имела длину 150 , тогда как обычно у вас также есть более короткие фактические входные последовательности, которые просто дополняются до длины 150. Однако в GRU обычно одна лежит в последнем фактическом входном токене, который, таким образом, также может находиться в позиции <150. Каков общий способ доступа к фактическому последнему токену или временному шагу (<= 150) вместо только технически последнего шага (всегда = 150)? </p>

Дополнительный вопрос: реверсируется ли вывод второго направления (поскольку направление, в котором информация проходит через ГРУ, также меняется на противоположное по сравнению с первым направлением), поэтому я должен фактически получить доступ к last_out_second_direction = tmp[0, 1, :] вместо tmp[149, 1, :]?

...