Как извлечь параметр из пользовательского оптимизатора в Tensorflow? - PullRequest
1 голос
/ 30 октября 2019

Я новичок как в python, так и в глубоком изучении, и я пытаюсь изучить поведение двухэтапного переключения с Adam на оптимизатор SGD с исходным кодом, который я нашел на github. Оптимизатор предназначен для переключения на SGD при выполнении условия запуска. Условия запуска следующие:

cond_update = gen_math_ops.logical_or(gen_math_ops.logical_and(gen_math_ops.logical_and( self.iterations > 1,  lg_err < 1e-2 ),   lam_t > 0 ), cond )[0]

Методом проб и ошибок я могу найти значение 1e-2, которое позволяет переключение произойти. Я хотел извлечь значение lg_err для построения графика и изучить его значение на всех итерациях. Я добавил несколько строк в исходный код и провел обучение в течение 10 эпох, пытаясь извлечь значение lg_err, но я получил следующий выходной файл csv. csv output

Этомодифицированный исходный код пользовательского оптимизатора, который я добавил в строку 8, 32, 71-74:

from tensorflow.python.framework import ops
from tensorflow.python.keras import optimizers
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
import csv

class SWATS(optimizers.Optimizer):
    def __init__(self,lr=0.001,lr_boost=10.0,beta_1=0.9,beta_2=0.999,epsilon=None,decay=0.,amsgrad=False,**kwargs):
        super(SWATS, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsgrad = amsgrad

    def get_updates(self, loss, params):
        def m_switch(pred, tensor_a, tensor_b):
            def f_true(): return tensor_a
            def f_false(): return tensor_b
            return control_flow_ops.cond(pred, f_true, f_false, strict=True)
        grads = self.get_gradients(loss, params)
        self.updates = []
        Trial = [] #create trial

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * ( 1. / (1. + self.decay * math_ops.cast(self.iterations,K.dtype(self.decay))) )

        with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
            t = math_ops.cast(self.iterations, K.floatx())
        lr_bc = gen_math_ops.sqrt(1. - math_ops.pow(self.beta_2, t)) / (1. - math_ops.pow(self.beta_1, t))

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        lams = [K.zeros(1, dtype=K.dtype(p)) for p in params]
        conds = [K.variable(False, dtype='bool') for p in params]
        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            vhats = [K.zeros(1) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats + lams + conds

        for p, g, m, v, vhat, lam, cond in zip(params, grads, ms, vs, vhats, lams, conds):
            beta_g = m_switch(cond, 1.0, 1.0 - self.beta_1)
            m_t = (self.beta_1 * m) + beta_g * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
            if self.amsgrad:
                vhat_t = math_ops.maximum(vhat, v_t)
                p_t_ada = lr_bc * m_t / (gen_math_ops.sqrt(vhat_t) + self.epsilon)
                self.updates.append(state_ops.assign(vhat, vhat_t))
            else:
                p_t_ada = lr_bc * m_t / (gen_math_ops.sqrt(v_t) + self.epsilon)
            gamma_den = math_ops.reduce_sum(p_t_ada * g)
            gamma = math_ops.reduce_sum(gen_math_ops.square(p_t_ada)) / (math_ops.abs(gamma_den) + self.epsilon) * (gen_math_ops.sign(gamma_den) + self.epsilon)
            lam_t = (self.beta_2 * lam) + (1. - self.beta_2) * gamma
            lam_prime = lam / (1. - math_ops.pow(self.beta_2, t))
            lam_t_prime = lam_t / (1. - math_ops.pow(self.beta_2, t))
            lg_err = math_ops.abs( lam_t_prime - gamma )

            # extract lg_err values into array

            Trial.append(lg_err)
            with open('lg_err_values', 'w', newline='') as myfile:
                wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
                wr.writerow(Trial)

            cond_update = gen_math_ops.logical_or(gen_math_ops.logical_and(gen_math_ops.logical_and( self.iterations > 1,  lg_err < 1e-4 ),   lam_t > 0 ), cond )[0]
            lam_update = m_switch(cond_update, lam, lam_t)
            self.updates.append(state_ops.assign(lam, lam_update))
            self.updates.append(state_ops.assign(cond, cond_update))

            p_t_sgd = (1. - self.beta_1) * lam_prime * m_t

            self.updates.append(state_ops.assign(m, m_t))
            self.updates.append(state_ops.assign(v, v_t))

            new_p = m_switch(cond, p - lr * p_t_sgd, p - lr * p_t_ada)

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(state_ops.assign(p, new_p))




        return self.updates

    def get_config(self):
        config = {
            'lr': float(K.get_value(self.lr)),
            'beta_1': float(K.get_value(self.beta_1)),
            'beta_2': float(K.get_value(self.beta_2)),
            'decay': float(K.get_value(self.decay)),
            'epsilon': self.epsilon,
            'amsgrad': self.amsgrad
        }
        base_config = super(SWATS, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

Заранее прошу прощения, если я оскорбил любой этикет вопросов в этом сообществе.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...