Если размер пакета может быть получен из len(labels)
:
def to_onehot(labels, n_categories, dtype=torch.float32):
batch_size = len(labels)
one_hot_labels = torch.zeros(size=(batch_size, n_categories), dtype=dtype)
for i, label in enumerate(labels):
# Subtract 1 from each LongTensor because your
# indexing starts at 1 and tensor indexing starts at 0
label = torch.LongTensor(label) - 1
one_hot_labels[i] = one_hot_labels[i].scatter_(dim=0, index=label, value=1.)
return one_hot_labels
, и у вас есть 6 категорий, и вы хотите, чтобы выходной сигнал был тензором целых чисел :
to_onehot(labels, n_categories=6, dtype=torch.int64)
tensor([[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 1],
[1, 0, 0, 0, 0, 0],
[1, 0, 0, 1, 1, 0],
[0, 0, 0, 1, 0, 0]])
Я бы придерживался torch.float32
на тот случай, если вы захотите использовать сглаживание меток, перепутывание или что-то в этом духе позже.