Вы можете использовать tf.contrib.estimator.stop_if_no_decrease_hook , как указано ниже:
estimator = tf.estimator.Estimator(model_fn, model_dir)
os.makedirs(estimator.eval_dir()) # TODO This should not be expected IMO.
early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
estimator,
metric_name='loss',
max_steps_without_decrease=1000,
min_steps=100)
tf.estimator.train_and_evaluate(
estimator,
train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
eval_spec=tf.estimator.EvalSpec(eval_input_fn))
Но если это не работает для вас, лучше использовать tf .estimator.experimental.stop_if_no_decrease_hook вместо.
Например:
estimator = ...
# Hook to stop training if loss does not decrease in over 100000 steps.
hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
Хук с ранней остановкой использует результаты оценки, чтобы решить, когда пора сократить тренировку, но вы необходимо указать количество этапов обучения, которые вы хотите отслеживать, и помнить, сколько оценок произойдет за это количество этапов. Если вы установите хук как hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 10000)
, хук будет рассматривать оценки, происходящие в диапазоне 10 000 шагов.
Подробнее о документации см. Здесь: https://www.tensorflow.org/api_docs/python/tf/estimator/experimental/stop_if_no_decrease_hook и для всех функций ранней остановки, которые вы можете использовать, вы можете обратиться к этому https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/early_stopping.py