Я тренирую модель с Keras 2.1.3 и Tensorflow 1.8.0, используя Keras fit_generator.Я определяю пользовательскую метрику, которую я использовал для контрольной точки модели: AUPRC (область под кривой точного возврата).Эта метрика добавляется в список метрик при компиляции модели
def as_keras_metric(method):
import functools
from keras import backend as K
import tensorflow as tf
@functools.wraps(method)
def wrapper(self, args, **kwargs):
""" Wrapper for turning tensorflow metrics into keras metrics """
value, update_op = method(self, args, **kwargs)
K.get_session().run(tf.local_variables_initializer())
with tf.control_dependencies([update_op]):
value = tf.identity(value)
return value
return wrapper
@as_keras_metric
def AUPRC(y_true, y_pred, curve='PR'):
return tf.metrics.auc(y_true, y_pred, curve=curve,summation_method='careful_interpolation')
Я добавляю AUPRC в список метрик при компиляции модели:
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy',AUPRC])
Во время обучения модели с помощью fit_generator ();Я замечаю что-то странное.Аттестация AUPRC (для всего набора значений Val), которая рассчитывается после каждого конца эпохи, сообщается метриками fit_generator: 0,7 после эпохи 1, 0,8 после эпохи 2, 0,85 эпохи 3;Затем я прекращаю обучение модели;что-то не так, поскольку эти числа намного выше, чем я ожидаю для этого набора данных
Затем я использовал обученную, но (ранее остановленную) модель, чтобы затем прогнозировать на том же наборе проверки ;используйте прогнозы модели для повторного расчета метрики AUPRC.Теперь я получаю ответ 0,45. Каким-то образом обученная модель сообщает о значительно сниженной производительности по сравнению с тем, что сообщается во время обучения модели (даже меньше, чем то, что мы видим после Эпохи 1)
Кажется, что во время тренировки с Keras fit_generator;Существует некоторая ошибка в Keras или Tensorflow, которая приводит к тому, что неверные значения сообщаются для проверки AUPRC.Или это ошибка в том, как метрика тензорного потока компилируется и используется в качестве метрики кераса?
Примечание: Люди могут сказать, что в различных реализациях AUPRC есть ошибки, поэтому я проверил это число двумя способами, Scikit Learn (AUC для sklearn.metrics.precision_recall_curve и сравнил это с использованиемtenorflow.metrics.auc (кривая PR)) и они оба дают мне 0,45)