Пользовательская функция потерь M CC в M XNet (коэффициент корреляции Мэтьюса) - PullRequest
1 голос
/ 29 марта 2020

Я пытаюсь написать функцию потери M CC. Я набросал макет, но это должно быть сделано в форме манипуляции с матрицей, как показано в справочном материале 1 (это не школьная программа). Поэтому следующий код является псевдокодом того, что я пытаюсь сделать.

class MCCLoss(Loss):
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(MCCLoss, self).__init__(weight, batch_axis, **kwargs)

    @staticmethod
    def compute_confusion_matrix_values(y_true, y_pred):
        tp = 0
        fp = 0
        tn = 0
        fn = 0

        for i in range(len(y_pred)):
            if y_true[i] == y_pred[i] == 1:
                tp += 1
            if y_pred[i] == 1 and y_true[i] != y_pred[i]:
                fp += 1
            if y_true[i] == y_pred[i] == 0:
                tn += 1
            if y_pred[i] == 0 and y_true[i] != y_pred[i]:
                fn += 1

        return tp, fp, tn, fn

    @staticmethod
    def matthews_corrcoef(F, tp, fp, tn, fn):
        # https://stackoverflow.com/a/56875660/992687
        x = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
        epsilon = np.finfo(np.float64).eps
        return ((tp * tn) - (fp * fn)) / F.sqrt(x + epsilon)

    def hybrid_forward(self, F, y_pred, y_true, sample_weight=None):
        tp, fp, tn, fn = self.compute_confusion_matrix_values(y_true, y_pred)
        loss = 1 - self.matthews_corrcoef(F, tp, fp, tn, fn)
        return loss

Я нашел некоторые ресурсы, которые являются весьма полезными, особенно пример реализации с использованием Keras в следующей ссылке Ссылка 1.

Я не уверен, что смогу использовать MakeLoss в Справочнике 2 для упрощения всего.

Справочник:

  1. Множественная классификация для M CC (Реализация коэффициента корреляции Мэтьюса) в Keras https://github.com/vlainic/matthews-correlation-coefficient/blob/master/multi_mcc_loss.py

  2. Пользовательская функция потери с использованием MakeLoss http://beta.mxnet.io/r/api/mx.symbol.MakeLoss.html https://blog.csdn.net/u013381011/article/details/79141680

  3. M XNet M CC metri c https://github.com/apache/incubator-mxnet/blob/56e79853ad5cf98baf84454eb595c7658bef6ee6/python/mxnet/metric.py#L838

Будет ли любой технический эксперт, помогите мне реализовать это? Нужны хорошие руки

Очень ценится Очень ценится

...