Кому делать ранний останов с оценочной потерей, используя tf.estimator.train_and_evaluate? - PullRequest
2 голосов
/ 14 марта 2020

Я использую оценщик Tensorflow и явно метод tf.estimator.train_and_evaluate(). Для тренировки есть крюк ранней остановки, равный tf.contrib.estimator.stop_if_no_decrease_hook, но у меня есть проблема, что потеря тренировки слишком нервная, чтобы использовать ее для ранней остановки. Кто-нибудь знает, как сделать раннюю остановку с tf.estimator на основе потери оценки ?

1 Ответ

1 голос
/ 17 марта 2020

Вы можете использовать tf.contrib.estimator.stop_if_no_decrease_hook , как указано ниже:

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))

Но если это не работает для вас, лучше использовать tf .estimator.experimental.stop_if_no_decrease_hook вместо.

Например:

estimator = ...
# Hook to stop training if loss does not decrease in over 100000 steps.
hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)

Хук с ранней остановкой использует результаты оценки, чтобы решить, когда пора сократить тренировку, но вы необходимо указать количество этапов обучения, которые вы хотите отслеживать, и помнить, сколько оценок произойдет за это количество этапов. Если вы установите хук как hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 10000), хук будет рассматривать оценки, происходящие в диапазоне 10 000 шагов.

Подробнее о документации см. Здесь: https://www.tensorflow.org/api_docs/python/tf/estimator/experimental/stop_if_no_decrease_hook и для всех функций ранней остановки, которые вы можете использовать, вы можете обратиться к этому https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/early_stopping.py

...