def generate_mask(data : list, max_seq_len : int):
"""
Generates a mask for data where each element is expected to be max_seq_len length after padding
Args:
data : The data being forwarded through LSTM after being converted to a tensor
max_seq_len : The length of the names after being padded
"""
batch_sz = len(data)
ret = torch.zeros(1,batch_sz, max_seq_len, dtype=torch.bool)
for i in range(batch_sz):
name = data[i]
for letter_idx in range(len(name)):
ret[0][i][letter_idx] = 1
return ret
У меня есть этот код для генерации маски, и я действительно ненавижу, как я это делаю. По сути, как вы можете видеть, я просто перебираю каждое имя и изменяю каждый индекс с 0 на длину имени до 1, я бы предпочел более элегантный способ сделать это.