В отличие от слоя torch.nn.CrossEntropyLoss
, который принимает значения метки для цели (т. Е. Если input
имеет форму (30, C, 96, 96, 96)
с C
числом классов, target
должно быть (30, 96, 96, 96)
), torch.nn.functional.binary_cross_entropy()
необходимо, чтобы input
и target
имели одинаковую форму (т. Е. target
формы (30, C, 96, 96, 96)
), поэтому требуется представление меток назначения в горячем виде.
Если вы не выберете torch.nn.CrossEntropyLoss
, у вас есть несколько способов разогреть целевые ярлыки (например, см. thread ).Персональное решение:
def to_one_hot(x, C=2, tensor_class=torch.FloatTensor):
""" One-hot a batched tensor of shape (B, ...) into (B, C, ...) """
x_one_hot = tensor_class(x.size(0), C, *x.shape[1:]).zero_()
x_one_hot = x_one_hot.scatter_(1, x.unsqueeze(1), 1)
return x_one_hot
# Demonstration:
num_classes = 2
labels = torch.LongTensor(30, 96, 96, 96).random_(0, num_classes)
one_hot_labels = to_one_hot(labels, C=num_classes)
print(one_hot_labels.shape)
# > torch.Size([30, 2, 96, 96, 96])