Как пакетно конвертировать длины предложений в маски в PyTorch? - PullRequest
0 голосов
/ 21 ноября 2018

Например, из

lens = [3, 5, 4]

мы хотим получить

mask = [[1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0]]

Оба из которых torch.LongTensor с.

Ответы [ 2 ]

0 голосов
/ 13 июня 2019

torch.arange(max_len)[None, :] < lens[:, None]

0 голосов
/ 21 ноября 2018

Один из способов, который я нашел:

torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)

Пожалуйста, поделитесь, если есть лучшие способы!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...