Добавьте аргумент ignore_index в функцию JaccardLoss (IoU) - PullRequest
0 голосов
/ 27 марта 2020

Я пытаюсь обучить модель CNN сегментации семанти c. Target size равно [32, 1, 99, 99], а output of model равно [32, 6, 99, 99], поскольку num_class равно 6. когда я использую JaccardLoss fun c во время тренировки, он говорит index 255 is out of bounds for dimension 0 with size 6. Теперь в моих ярлыках есть некоторые точки, где нет класса, поэтому для этих точек используйте 255. Эта проблема решается torch.nn.CrossEntropy с аргументом ignore_index при установке на 255. Но, поскольку я пытаюсь использовать функцию JaccardLoss, 255 создает ошибку, поскольку она не игнорируется. Так может кто-нибудь, пожалуйста, помогите мне с кодом, чтобы игнорировать 255. Я прилагаю функцию JaccardLoss, которую я использую

def jaccard_loss(true, logits, eps=1e-7):
    """Computes the Jaccard loss, a.k.a the IoU loss.
    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the jaccard loss so we
    return the negated jaccard loss.
    Args:
        true: a tensor of shape [B, H, W] or [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.
    Returns:
        jacc_loss: the Jaccard loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    union = cardinality - intersection
    jacc_loss = (intersection / (union + eps)).mean()
    return (1 - jacc_loss)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...