После того, как моя первая версия с использованием цикла for оказалась неэффективной, это самое быстрое решение, которое я придумал, для двух тензоров одинакового размера prediction
и truth
:
def confusion(prediction, truth):
confusion_vector = prediction / truth
true_positives = torch.sum(confusion_vector == 1).item()
false_positives = torch.sum(confusion_vector == float('inf')).item()
true_negatives = torch.sum(torch.isnan(confusion_vector)).item()
false_negatives = torch.sum(confusion_vector == 0).item()
return true_positives, false_positives, true_negatives, false_negatives
Прокомментированная версия и тест-кейс на https://gist.github.com/the-bass/cae9f3976866776dea17a5049013258d