Почему мои две переменные тензорного потока не обновляются синхронно? - PullRequest
0 голосов
/ 04 июня 2019

Я пытаюсь (правильно или нет) написать модифицированную форму оптимизатора Keras SGD с бэкэндом Tensorflow.Идея состоит в том, чтобы запланировать «перезапуски» SGD в определенные эпохи, используя меньшие скорости обучения, но без сохранения и перезагрузки модели.Чтобы скорость затухания для скорости обучения вела себя так, как при таком перезапуске, я хочу отслеживать не только общее количество итераций, но и количество итераций с момента последнего «перезапуска».

Итак, я инициализирую свой счетчик итерации с момента последнего перезапуска (self.iteration_ref) до 0 при создании моего объекта оптимизатора SGD (т. Е. SGD_VAR), так же как self.iterations инициализируется равным 0. Затем с каждой итерацией я увеличиваю каждый счетчикна 1, если только нет сброса, в этом случае я сбрасываю свой счетчик итераций (self.iterations_ref) на 1. Здесь показан код, который я использую (он унаследован от класса SGD Кераса и вносит лишь небольшие изменения:

class SGD_VAR(SGD):
"""Stochastic gradient descent optimizer.

Includes support for momentum,
learning rate decay, and Nesterov momentum.

# Arguments
    lr: float >= 0. Learning rate.
    momentum: float >= 0. Parameter that accelerates SGD
        in the relevant direction and dampens oscillations.
    decay: float >= 0. Learning rate decay over each update.
    nesterov: boolean. Whether to apply Nesterov momentum.
"""

def __init__(self, lr=0.05, momentum=0., decay=0.,
             nesterov=False, lr_dict = {},
             batches_per_epoch = 1562,
             **kwargs):

    super(SGD_VAR, self).__init__(lr, momentum, decay,
                                  nesterov, **kwargs)
    if lr_dict == {}:
        lr_dict = {0:lr}

    self.lr_dict = lr_dict
    self.batches_per_epoch = batches_per_epoch

    with K.name_scope(self.__class__.__name__):
        # Here is where I initialize *MY* iterations counter
        self.iterations_ref = K.variable(0, dtype='int64', 
                                         name='iterations_ref')
        self.new_lr = K.variable(lr, name='new_lr')

@interfaces.legacy_get_updates_support
def get_updates(self, loss, params):


    def lr_stepper(iteration, lr):
        ''' Wrapped python method used by tensor 
            to determine desired learning rate'''

        # Change the learning rate when specified 
        # in lr_dict(dict of epochs: learning rates)
        for x in self.lr_dict:
            temp = tf.Variable((x-1) * self.batches_per_epoch, 
                               dtype=iteration.dtype)
            if tf.equal(temp, iteration):
                return tf.constant(self.lr_dict[x], dtype=lr.dtype)

        return lr

    # NOTE: K.update_add and K.update 
            return tf.assign_add and tf.assign, respectively
    self.updates = [K.update_add(self.iterations, 1)]


    # Key lines to change self.lr
    new_lr = tf.contrib.eager.py_func(func=lr_stepper,
                                      inp=[self.iterations, self.lr], 
                                      Tout=tf.float32)

    # Note: self.lr != new_lr indicates a RESET has occurred
    new_iter_ref = tf.cond(tf.math.equal(self.lr,new_lr),
                           lambda: K.update_add(self.iterations_ref, 1),
                           lambda: K.update(self.iterations_ref, 1))
    self.updates.append(K.update(self.lr, new_lr))
    self.updates.append(new_iter_ref)

    # Temporary code to debug output
    self.iterations = tf.Print(self.lr,
             [self.iterations,self.iterations_ref, self.lr],
                               message="\n Debug Vals:" )

Я использую tf.Print для распечатки self.iterations, self.iterations_ref и self.lr. Каждое число заключено в квадратные скобки. Я бы ожидал, что tf.Print покажет self.iterations и self.iterations_ref быть равными друг другу (исключая эффекты любых перезагрузок), но вместо этого я вижу, что они сохраняют разницу в 1 - т.е. вывод, который я вижу:

 Debug Vals:[1][0][0.1]
 Debug Vals:[2][1][0.1]
 Debug Vals:[3][2][0.1]
 Debug Vals:[4][3][0.1]

...

Я ожидал:

 Debug Vals:[1][1][0.1]
 Debug Vals:[2][2][0.1]
 Debug Vals:[3][3][0.1]
 Debug Vals:[4][4][0.1]

...

Почему это?(Примечание: я использую Keras 2.2.4 и tenorflow 1.8)

...