Keras: исчезновение слоя или изменение переменной в течение сеанса - PullRequest
0 голосов
/ 06 апреля 2019

Я пытаюсь постепенно растушить слой кераса на несколько партий. Поэтому я написал пользовательский слой "DecayingSkip". Кроме того, я добавляю его на другой слой. Я пытаюсь реализовать исчезающее соединение пропустить. Однако код, кажется, не работает правильно. Модель компилируется и обучается, но активация слоя не исчезла, как ожидалось. Что я делаю не так?

class DecayingSkip(Layer):
    def __init__(self, fade_out_at_batch, **kwargs):
        self.fade_out_at_batch = K.variable(fade_out_at_batch)
        self.btch_cnt = K.variable(0)
        super(decayingSkip, self).__init__(**kwargs)

    def call(self, x):
        self.btch_cnt = self.btch_cnt + 1.0
        return K.switch(
            self.btch_cnt >= self.fade_out_at_batch,
            x * 0,
            x *  (1.0 - ((1.0 / self.fade_out_at_batch) * self.btch_cnt))
        )

def add_fade_out(fadeOutLayer, layer, fade_out_at_batch):
    cnn_match = Conv2D(filters=int(layer.shape[-1]), kernel_size=1, activation=bounded_relu)(fadeOutLayer)
    fadeOutLayer = DecayingSkip(fade_out_at_batch=fade_out_at_batch, name=name + '_fade_out')(cnn_match)
    return Add()([fadeOutLayer, layer])

Кроме того, в другой попытке я попытался использовать переменную tenorflow, которую я изменил в сеансе, например:

def add_fade_out(fadeOutLayer, layer):
    fadeOutLayer = Conv2D(filters=int(layer.shape[-1]), kernel_size=1, activation='relu')(fadeOutLayer)
    alph = K.variable(1.0, name='alpha')
    fadeOutLayer = Lambda(lambda x: x * alph)(fadeOutLayer)
    return Add()([fadeOutLayer, layer])

sess = K.get_session()
lw = sess.graph.get_tensor_by_name("alpha:0") 
sess.run(K.tf.assign(lw, new_value))

Это тоже не сработало. Почему?

1 Ответ

0 голосов
/ 07 апреля 2019

Я думаю, что нашел решение.Я изменил функцию вызова слоя на:

def call(self, x):
    self.btch_cnt = K.tf.assign_add(self.btch_cnt, 1)
    K.get_session().run(self.btch_cnt)

    return K.switch(
        self.btch_cnt >= self.fade_out_at_batch,
        x * 0,
        x * (1.0 - ((1.0 / self.fade_out_at_batch) * self.btch_cnt))
    )
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...