Я пытаюсь включить в метод компиляции пользовательскую функцию R2. Однако результаты метрик связаны и не совпадают с результатами при расчете вне метода компиляции. Я использую набор данных Mercedes-Greener-Manufacturing от Kaggle. Не могли бы вы помочь?
model.compile(optimizer='adam',
loss='mean_squared_error',
metrics=[R2_tf()])
Вот класс пользовательских метрик
class R2_tf(tf.keras.metrics.Metric):
def __init__(self, name='R2', **kwargs):
super(R2_tf, self).__init__(name=name, **kwargs)
self.R2 = self.add_weight(name='r2', initializer=None)
def update_state(self, y_true, y_pred, sample_weight=None):
y_pred = tf.squeeze(y_pred)
y_pred = tf.cast(y_pred, 'float32')
y_true = tf.cast(y_true, 'float32')
SSres = tf.reduce_sum(tf.square(tf.math.subtract(y_true, y_pred)))
SStot = tf.reduce_sum(tf.square(tf.math.subtract(y_true, tf.reduce_mean(y_true))))
values = tf.subtract(1., tf.divide(SSres, SStot))
values = tf.cast(values, 'float32')
self.R2.assign(values)
def result(self):
return self.R2
def reset_states(self):
# The state of the metric will be reset at the start of each epoch.
self.R2.assign(0.)