R в квадрате пользовательских метрик в Tensorflow 2 - PullRequest
0 голосов
/ 29 апреля 2020

Я пытаюсь включить в метод компиляции пользовательскую функцию R2. Однако результаты метрик связаны и не совпадают с результатами при расчете вне метода компиляции. Я использую набор данных Mercedes-Greener-Manufacturing от Kaggle. Не могли бы вы помочь?

    model.compile(optimizer='adam',
                 loss='mean_squared_error',
                 metrics=[R2_tf()])

Вот класс пользовательских метрик

class R2_tf(tf.keras.metrics.Metric):
    def __init__(self, name='R2', **kwargs):
        super(R2_tf, self).__init__(name=name, **kwargs)
        self.R2 = self.add_weight(name='r2', initializer=None)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.squeeze(y_pred)
        y_pred = tf.cast(y_pred, 'float32')

        y_true = tf.cast(y_true, 'float32')

        SSres = tf.reduce_sum(tf.square(tf.math.subtract(y_true, y_pred)))
        SStot = tf.reduce_sum(tf.square(tf.math.subtract(y_true, tf.reduce_mean(y_true))))

        values = tf.subtract(1., tf.divide(SSres, SStot))
        values = tf.cast(values, 'float32')

        self.R2.assign(values)

    def result(self):
        return self.R2

    def reset_states(self):
      # The state of the metric will be reset at the start of each epoch.
      self.R2.assign(0.)
...