Как бы вы сделали ReduceLROnPlateau в Tensorflow? - PullRequest
0 голосов
/ 10 мая 2018

Keras имеет функцию обратного вызова, которая снижает скорость обучения при плато по указанному метрике, называемую ReduceLROnPlateau.

Как создать такую ​​функцию в собственном Tensorflow?Можно ли в модели Tensorflow вызывать обратные вызовы Keras?Или это нужно писать на родном Tensorflow?Если да, то как бы вы установили скорость обучения в середине тренировки?

Ответы [ 2 ]

0 голосов
/ 17 апреля 2019

Вот не преобразование 1: 1 из Keras 'ReduceLROnPlateau', которое я написал. Он проверяет потерю каждой партии, а не случайную выборку в конце каждой эпохи. Время перезарядки и терпение все еще в эпоху, хотя. Его можно использовать так же, как tf.train.exponential_decay (...).

Я думаю, что, возможно, есть лучший способ сделать это, чем просто контролировать минимальное значение потерь, поскольку минимальное значение может быть экстремальным выбросом. Метрика с точки зрения некоторого скользящего среднего градиента потерь может быть лучше.

def plateau_decay(learning_rate, global_step, loss, data_count, batch_size, factor=0.1, patience=10, min_delta=1e-4, cooldown=0, min_lr=0):
steps_per_epoch = math.ceil(data_count // batch_size)
patient_steps = patience * steps_per_epoch
cooldown_steps = cooldown * steps_per_epoch

if not isinstance(learning_rate, tf.Tensor):
    learning_rate = tf.get_variable('learning_rate', initializer=tf.constant(learning_rate), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])

with tf.variable_scope('plateau_decay'):
    step = tf.get_variable('step', trainable=False, initializer=global_step, collections=[tf.GraphKeys.LOCAL_VARIABLES])
    best = tf.get_variable('best', trainable=False, initializer=tf.constant(np.Inf, tf.float32), collections=[tf.GraphKeys.LOCAL_VARIABLES])

    def _update_best():
        with tf.control_dependencies([
            tf.assign(best, loss),
            tf.assign(step, global_step),
            tf.print('Plateau Decay: Updated Best - Step:', global_step, 'Next Decay Step:', global_step + patient_steps, 'Loss:', loss)
        ]):
            return tf.identity(learning_rate)

    def _decay():
        with tf.control_dependencies([
            tf.assign(best, loss),
            tf.assign(learning_rate, tf.maximum(tf.multiply(learning_rate, factor), min_lr)),
            tf.assign(step, global_step + cooldown_steps),
            tf.print('Plateau Decay: Decayed LR - Step:', global_step, 'Next Decay Step:', global_step + cooldown_steps + patient_steps, 'Learning Rate:', learning_rate)
        ]):
            return tf.identity(learning_rate)

    def _no_op(): return tf.identity(learning_rate)

    met_threshold = tf.less(loss, best - min_delta)
    should_decay = tf.greater_equal(global_step - step, patient_steps)

    return tf.cond(met_threshold, _update_best, lambda: tf.cond(should_decay, _decay, _no_op))
0 голосов
/ 10 мая 2018

Боюсь, тензор потока не поддерживает это из коробки (и обратные вызовы keras также не применимы напрямую). Вот список поддерживаемых методов планирования скорости обучения : все они представляют собой разные алгоритмы, но являются автономными , т.е. не зависят от результатов обучения.

Но хорошая новость заключается в том, что все оптимизаторы принимают тензор для скорости обучения. Таким образом, вы можете создать переменную или заполнитель для скорости обучения и изменить ее значение в зависимости от эффективности валидации (которую вам также необходимо рассчитать самостоятельно). Вот пример из этого чудесного ответа :

learning_rate = tf.placeholder(tf.float32, shape=[])
# ...
train_step = tf.train.GradientDescentOptimizer(
    learning_rate=learning_rate).minimize(mse)

sess = tf.Session()

# Feed different values for learning rate to each training step.
sess.run(train_step, feed_dict={learning_rate: 0.1})
sess.run(train_step, feed_dict={learning_rate: 0.1})
sess.run(train_step, feed_dict={learning_rate: 0.01})
sess.run(train_step, feed_dict={learning_rate: 0.01})
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...