Отображение списка меток с одним горячим кодированием - PullRequest
1 голос
/ 22 марта 2020

Как я могу сделать то же самое, когда 'label' - это список? Например: label = [2,4,6,1,7 ..., 9]

label = 3
NumClass = 10
NumRows  = 100

mask    =  torch.zeros(100, 64)
ones     =  torch.ones(1, 64)
ElementsPerClass = NumRows//NumClass
mask [ ElementsPerClass*label : ElementsPerClass*(label+1) ] = ones

1 Ответ

0 голосов
/ 22 марта 2020

Вы ищете scatter:

NumRows = len(label)
mask = torch.zeros((NumRoes, NumClass)).scatter_(dim=1, index=torch.tensor(label, dtype=torch.long)[:, None], src=torch.ones(NumRows, 1))
...