Как настроить пользовательские метрики keras, которые будут вызываться только в конце эпохи? - PullRequest
0 голосов
/ 11 февраля 2019

Я пытаюсь использовать пользовательские метрики для моей нейронной сети, и эта метрика должна оцениваться только в конце эпохи.Проблема, с которой я сталкиваюсь, заключается в том, что показатели оцениваются в каждом пакете, что не соответствует требуемому поведению.Обратите внимание, что я работаю с генераторами и fit_generator с keras.

validation_data загружается с генератором, который реализует keras.utils.Sequence

class DataGenerator(keras.utils.Sequence): 
   def __init__(self, inputs, labels, batch_size):
    self.inputs = inputs
    self.labels = labels
    self.batch_size = batch_size

   def __getitem__(self, index):
    #some processing done here
    return batch_inputs, batch_labels

   def __len__(self):
    return int(np.floor(len(self.inputs) / self.batch_size))

Я пытался реализовать то, что предлагает документация keras, ноЯ не нашел никакой информации, чтобы указать, что метрика должна использоваться только в конце эпохи.

def auc_roc(y_true, y_pred):
   auc, up_opt = tf.metrics.auc(y_true, y_pred)
   K.get_session().run(tf.local_variables_initializer())
   with tf.control_dependencies([up_opt]):
       auc = tf.identity(auc)
   return auc

Так что сейчас auc_roc вызывается после каждого пакета вместо одного вызова в концеepoch.

1 Ответ

0 голосов
/ 11 февраля 2019
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class IntervalEvaluation(Callback):
    def __init__(self, validation_data=(), interval=10):
        super(Callback, self).__init__()

        self.interval = interval
        self.X_val, self.y_val = validation_data

    def on_epoch_end(self, epoch, logs={}):
        if epoch % self.interval == 0:
            y_pred = self.model.predict_proba(self.X_val, verbose=0)
            score = roc_auc_score(self.y_val, y_pred)
            print("interval evaluation - epoch: {:d} - score: {:.6f}".format(epoch, score))

Использование:

ival = IntervalEvaluation(validation_data=(x_test2, y_test2), interval=1)

Подробнее: http://digital -thinking.de / keras-три способа использования пользовательской проверки-метрики-в-керасе/

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...