Вы можете создать тензор двухмерной маски с количеством True в i-й строке, заданным как lengths[i]
. Вот один пример:
batch_size = 6
n = 5
preds = torch.arange(batch_size * n).reshape(batch_size, n)
# tensor([[ 0, 1, 2, 3, 4],
# [ 5, 6, 7, 8, 9],
# [10, 11, 12, 13, 14],
# [15, 16, 17, 18, 19],
# [20, 21, 22, 23, 24],
# [25, 26, 27, 28, 29]])
#lengths = np.random.randint(0, n+1, batch_size)
lengths = torch.randint(0, n+1, (batch_size, ))
# tensor([2, 0, 5, 3, 3, 2])
Давайте создадим маску и получим наш результат (возможно, есть лучший способ создать такую маску, но я придумал это):
#mask = np.tile(range(n), (batch_size,1)) < lengths[:,None]
mask = torch.arange(n).repeat((batch_size,1)) < lengths[:, None]
# tensor([[ True, True, False, False, False],
# [False, False, False, False, False],
# [ True, True, True, True, True],
# [ True, True, True, False, False],
# [ True, True, True, False, False],
# [ True, True, False, False, False]])
#result = preds[mask]
result = torch.masked_select(preds, mask)
# tensor([0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 25, 26])
Это дает тот же результат, что и ваш код:
res_pred = []
for i in range(len(preds)):
length = lengths[i].item()
res_pred += [preds[i][:length]]
result = torch.cat(res_pred).flatten()