Pytorch разрезал 2d массив по длине в первом измерении - PullRequest
1 голос
/ 04 марта 2020

У меня есть двумерный массив, скажем, размером torch.tensor(batch_size, 1000). Массив 1000 из второго измерения на самом деле переменной длины. У меня есть второй массив размером [batch_size], содержащий длину для каждой строки ...

Вот пример кода:

# preds is the 2d array of size [batch_size, 1000]
# lengths is a 1d array containing the lengths of each row of preds
res_pred = []
for i in range(len(preds)):
    length = lengths[i].item()
    res_pred += [preds[i][:length]]

result = torch.cat(res_pred).flatten()

Я делаю то же самое для своих целей и тогда я могу применить функцию потерь к обоим.

Мне было интересно, если бы была одна векторизованная операция, которую я мог бы сделать, чтобы извлечь все batch_size векторов переменной длины и torch.cat их вместе. Прямо сейчас я зацикливаюсь на первом измерении, но это чувствуется медленно.

Спасибо,

1 Ответ

1 голос
/ 05 марта 2020

Вы можете создать тензор двухмерной маски с количеством True в i-й строке, заданным как lengths[i]. Вот один пример:

batch_size = 6
n = 5

preds = torch.arange(batch_size * n).reshape(batch_size, n)
# tensor([[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9],
#         [10, 11, 12, 13, 14],
#         [15, 16, 17, 18, 19],
#         [20, 21, 22, 23, 24],
#         [25, 26, 27, 28, 29]])

#lengths = np.random.randint(0, n+1, batch_size)
lengths = torch.randint(0, n+1, (batch_size, ))
# tensor([2, 0, 5, 3, 3, 2])

Давайте создадим маску и получим наш результат (возможно, есть лучший способ создать такую ​​маску, но я придумал это):

#mask = np.tile(range(n), (batch_size,1)) < lengths[:,None]
mask = torch.arange(n).repeat((batch_size,1)) < lengths[:, None]
# tensor([[ True,  True, False, False, False],
#        [False, False, False, False, False],
#        [ True,  True,  True,  True,  True],
#        [ True,  True,  True, False, False],
#        [ True,  True,  True, False, False],
#        [ True,  True, False, False, False]])

#result = preds[mask]
result = torch.masked_select(preds, mask)
# tensor([0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 25, 26])

Это дает тот же результат, что и ваш код:

res_pred = []
for i in range(len(preds)):
    length = lengths[i].item()
    res_pred += [preds[i][:length]]

result = torch.cat(res_pred).flatten()
...