Вот моя реализация, использующая torch.nn.utils.rnn.pad_sequence()
:
in_tensor = torch.rand((9, 3))
print(in_tensor)
print(36*'=')
lengths = torch.tensor([3, 4, 2])
cum_len = 0
y = []
for idx, val in enumerate(lengths):
y.append(in_tensor[cum_len : cum_len+val])
cum_len += val
print(torch.nn.utils.rnn.pad_sequence(y, batch_first=True)))
вывод:
# in_tensor of shape (9 x 3)
tensor([[0.9169, 0.3549, 0.6211],
[0.4832, 0.5475, 0.8862],
[0.8708, 0.5462, 0.9374],
[0.4605, 0.1167, 0.5842],
[0.1670, 0.2862, 0.0378],
[0.2438, 0.5742, 0.4907],
[0.1045, 0.5294, 0.5262],
[0.0805, 0.2065, 0.2080],
[0.6417, 0.4479, 0.0688]])
====================================
# out tensor of shape (len(lengths) x max(lengths) x b), in this case b is 3
tensor([[[0.9169, 0.3549, 0.6211],
[0.4832, 0.5475, 0.8862],
[0.8708, 0.5462, 0.9374],
[0.0000, 0.0000, 0.0000]],
[[0.4605, 0.1167, 0.5842],
[0.1670, 0.2862, 0.0378],
[0.2438, 0.5742, 0.4907],
[0.1045, 0.5294, 0.5262]],
[[0.0805, 0.2065, 0.2080],
[0.6417, 0.4479, 0.0688],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]]])