Эпоха управления с пользовательским оценщиком - PullRequest
0 голосов
/ 07 июня 2019

Моя функция ввода выглядит так:

    def input_fn():
        dataset.repeat(epochs).batch(16)
    estimator_model.train(input_fn, steps)

Как я могу сообщить своей модели, что это n-й повтор (эпоха) по набору данных? Я хотел бы реализовать такие вещи, как снижение скорости обучения, модель обучения без состязательной потери для первых n эпох и т. Д. Я использую tf.data.Dataset и tf.estimator.Estimator. Если я вызываю метод поезда несколько раз:

    def input_fn():
        dataset.batch(16)
    for epoch in range(epochs):
        estimator_model.train(input_fn, steps)

будет перестроена модель (разные веса, разные каталоги контрольных точек, разные журналы тензорного потока) - для меня это неприемлемо.

До оценки я бы сделал:

for epoch in range(epochs):
    for iter, data in enumerate(dataset):
        model.train(data, epoch)

Теперь такой код глубоко в духе Estimator и Dataset, и я не могу его контролировать - поэтому мне трудно делать такие вещи, как снижение скорости обучения и т. Д. (Сделать что-то для первой / последней n эпох ).

1 Ответ

0 голосов
/ 07 июня 2019

Если вы знаете размер вашего поезда, вы можете установить параметр steps_per_epoch = train_size//batch_size.Затем в вашем model_fn запросе тензор global_step = tf.train.get_global_step() и затем получите количество эпох, прошедших как тензор epochs_passed = tf.cast(global_step, tf.float32)/steps_per_epoch.

Для многих приложений, таких как график обучения, который вы упомянули, часто бывает идиоматичнеепросто используйте tf.train.piecewise_constant_decay, основанный на аналогичной концепции.

...