Очень похоже на функцию binary_accuracy
, определенную в руководстве, вы можете реализовать любую метрику c, какую захотите. Все, что вам нужно, - это набор предсказаний модели (в данном случае preds
) и истинных целей (y
).
Например, для матрицы путаницы вы можете сделать следующее:
from sklearn.metrics import confusion_matrix
def compute_confusion_matrix(preds, y):
#round predictions to the closest integer
rounded_preds = torch.round(torch.sigmoid(preds))
return confusion_matrix(y, rounded_preds)