Я пытаюсь вставить упакованную и дополненную последовательность через GRU и получить выходные данные последнего элемента каждой последовательности. Конечно, я имею в виду не элемент -1
, а фактический последний элемент без дополнения. Мы заранее знаем длины последовательностей, поэтому должно быть так же просто, как извлечь для каждой последовательности элемент length-1
.
Я попробовал следующее
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# Data
input = torch.Tensor([[[0., 0., 0.],
[1., 0., 1.],
[1., 1., 0.],
[1., 0., 1.],
[1., 0., 1.],
[1., 1., 0.]],
[[1., 1., 0.],
[0., 1., 0.],
[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[1., 0., 0.],
[1., 1., 1.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[1., 1., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]])
lengths = [6, 4, 3, 1]
p = pack_padded_sequence(input, lengths, batch_first=True)
# Forward
gru = torch.nn.GRU(3, 12, batch_first=True)
packed_output, gru_h = gru(p)
# Unpack
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
last_seq_idxs = torch.LongTensor([x-1 for x in input_sizes])
last_seq_items = torch.index_select(output, 1, last_seq_idxs)
print(last_seq_items.size())
# torch.Size([4, 4, 12])
Но форма не та, которую я ожидаю. Я ожидал получить 4x12
, т.е. last item of each individual sequence x hidden
.`
Я мог бы пройтись по всему циклу и создать новый тензор, содержащий нужные мне элементы, но я надеялся на встроенный подход, который использовал бы некоторую умную математику. Боюсь, что ручное создание циклов и сборка приведут к очень низкой производительности.