Как изменить скорость обучения в Tensorflow в зависимости от количества пакетов и эпох? - PullRequest
1 голос
/ 06 августа 2020

Есть ли возможность реализовать следующий сценарий с Tensorflow:

В первых N пакетах скорость обучения должна быть увеличена с 0 до 0,001. После того, как это количество пакетов будет достигнуто, скорость обучения должна медленно уменьшаться с 0,001 до 0,00001 после каждой эпохи.

Как я могу объединить эту комбинацию в обратном вызове? Tensorflow предлагает tf.keras.callbacks.LearningRateScheduler и функции обратного вызова on_train_batch_begin () или on_train_batch_end (). Но я не буду рассматривать общую комбинацию этих обратных вызовов.

Может ли кто-нибудь дать мне подход, как создать такой комбинированный обратный вызов, который зависит от количества пакетов и эпох?

1 Ответ

1 голос
/ 06 августа 2020

Что-нибудь вроде , это сработает. Я не тестировал это и не пытался доводить до совершенства ... но детали есть, так что вы можете заставить его работать так, как вам нравится.

import tensorflow as tf
from tensorflow.keras.callbacks import Callback
import numpy as np

class LRSetter(Callback):
    
    def __init__(self, start_lr=0, middle_lr=0.001, end_lr=0.00001, 
                 start_mid_batches=200, end_epochs=2000):
        
        self.start_mid_lr = np.linspace(start_lr, middle_lr, start_mid_batches)
        #Not exactly right since you'll have gone through a couple epochs
        #but you get the picture
        self.mid_end_lr = np.linspace(middle_lr, end_lr, end_epochs) 
        
        self.start_mid_batches = start_mid_batches
        
        self.epoch_takeover = False
        
    def on_train_batch_begin(self, batch, logs=None):
    
        if batch < self.start_mid_batches:
            tf.keras.backend.set_value(self.model.optimizer.lr, self.start_mid_lr[batch])
        else:
            self.epoch_takeover = True

    def on_epoch_begin(self, epoch):
        if self.epoch_takeover:
            tf.keras.backend.set_value(self.model.optimizer.lr, self.mid_end_lr[epoch])
...