Обратный вызов Early Stop в keras - PullRequest
3 голосов
/ 29 мая 2020

Как можно эффективно остановить процесс подгонки модели обучения с помощью обратного вызова в keras? До сих пор я пробовал различные подходы, в том числе приведенный ниже.

class EarlyStoppingCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold):
        super(EarlyStoppingCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        accuracy = logs["accuracy"]
        if accuracy >= self.threshold:
            print("Stopping early!")
            self.model.stop_training = True

Обратный вызов выполняется, однако self.model.stop_training = True, похоже, не действует. Печать выполнена успешно, но модель продолжает обучение. Есть идеи, как решить эту проблему? Моя версия тензорного потока: tenorflow == 1.14.0

1 Ответ

1 голос
/ 02 июня 2020

Вероятно, у вас возникла следующая проблема: https://github.com/tensorflow/tensorflow/issues/37587.

Короче говоря - всякий раз, когда вызывается model.predict или model.evaluate, model.stop_training сбрасывается на False. Я смог воспроизвести это поведение, используя ваш EarlyStoppingCallback, за которым последовал другой обратный вызов, который вызывал model.predict для некоторого фиксированного набора данных.

Обходной путь заключается в том, чтобы сначала поместить обратные вызовы, вызывающие model.predict или model.evaluate, перед любыми обратными вызовами, которые могут захотеть установить model.stop_training на True. Также похоже, что проблема была исправлена ​​в TF 2.2.

...