Я пытаюсь обучить модель CNN сегментации семанти c. Target size
равно [32, 1, 99, 99]
, а output of model
равно [32, 6, 99, 99]
, поскольку num_class
равно 6
. когда я использую JaccardLoss fun c во время тренировки, он говорит index 255 is out of bounds for dimension 0 with size 6
. Теперь в моих ярлыках есть некоторые точки, где нет класса, поэтому для этих точек используйте 255
. Эта проблема решается torch.nn.CrossEntropy
с аргументом ignore_index
при установке на 255. Но, поскольку я пытаюсь использовать функцию JaccardLoss, 255
создает ошибку, поскольку она не игнорируется. Так может кто-нибудь, пожалуйста, помогите мне с кодом, чтобы игнорировать 255. Я прилагаю функцию JaccardLoss, которую я использую
def jaccard_loss(true, logits, eps=1e-7):
"""Computes the Jaccard loss, a.k.a the IoU loss.
Note that PyTorch optimizers minimize a loss. In this
case, we would like to maximize the jaccard loss so we
return the negated jaccard loss.
Args:
true: a tensor of shape [B, H, W] or [B, 1, H, W].
logits: a tensor of shape [B, C, H, W]. Corresponds to
the raw output or logits of the model.
eps: added to the denominator for numerical stability.
Returns:
jacc_loss: the Jaccard loss.
"""
num_classes = logits.shape[1]
if num_classes == 1:
true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
true_1_hot_f = true_1_hot[:, 0:1, :, :]
true_1_hot_s = true_1_hot[:, 1:2, :, :]
true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
pos_prob = torch.sigmoid(logits)
neg_prob = 1 - pos_prob
probas = torch.cat([pos_prob, neg_prob], dim=1)
else:
true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(logits, dim=1)
true_1_hot = true_1_hot.type(logits.type())
dims = (0,) + tuple(range(2, true.ndimension()))
intersection = torch.sum(probas * true_1_hot, dims)
cardinality = torch.sum(probas + true_1_hot, dims)
union = cardinality - intersection
jacc_loss = (intersection / (union + eps)).mean()
return (1 - jacc_loss)