Заменить все индексы в тензоре в диапазоне на 1 с - PullRequest
1 голос
/ 07 января 2020
def generate_mask(data : list, max_seq_len : int):
    """
    Generates a mask for data where each element is expected to be max_seq_len length after padding
    Args:
    data : The data being forwarded through LSTM after being converted to a tensor
    max_seq_len : The length of the names after being padded
    """
    batch_sz = len(data)
    ret = torch.zeros(1,batch_sz, max_seq_len, dtype=torch.bool)
    for i in range(batch_sz):
        name = data[i]

        for letter_idx in range(len(name)):
            ret[0][i][letter_idx] = 1

    return ret

У меня есть этот код для генерации маски, и я действительно ненавижу, как я это делаю. По сути, как вы можете видеть, я просто перебираю каждое имя и изменяю каждый индекс с 0 на длину имени до 1, я бы предпочел более элегантный способ сделать это.

1 Ответ

3 голосов
/ 07 января 2020

Ну, вы можете упростить что-то вроде этого:

# [...]
for i in range(batch_sz):
    ret[0, i, :len(data[i])] = 1
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...