Как эффективно рассчитать путаницу в PyTorch? - PullRequest
0 голосов
/ 30 мая 2018

У меня есть тензор, содержащий мои прогнозы, и тензор, который содержит фактические метки для моей проблемы двоичной классификации.Как эффективно рассчитать матрицу путаницы?

1 Ответ

0 голосов
/ 30 мая 2018

После того, как моя первая версия с использованием цикла for оказалась неэффективной, это самое быстрое решение, которое я придумал, для двух тензоров одинакового размера prediction и truth:

def confusion(prediction, truth):
    confusion_vector = prediction / truth

    true_positives = torch.sum(confusion_vector == 1).item()
    false_positives = torch.sum(confusion_vector == float('inf')).item()
    true_negatives = torch.sum(torch.isnan(confusion_vector)).item()
    false_negatives = torch.sum(confusion_vector == 0).item()

    return true_positives, false_positives, true_negatives, false_negatives

Прокомментированная версия и тест-кейс на https://gist.github.com/the-bass/cae9f3976866776dea17a5049013258d

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...