Tensorflow Keras изменяют переменную модели из обратного вызова - PullRequest
2 голосов
/ 06 февраля 2020

Я пытаюсь изменить необучаемую переменную модели из обратного вызова в начале каждой эпохи. По сути, я хотел бы иметь механизм, аналогичный планировщику скорости обучения (который имеет встроенную инфраструктуру в TF), но применимый к произвольной переменной модели. Код ниже является минимальным примером, чтобы показать концепцию. Я пытаюсь изменить переменную decay , но она не работает. По-видимому, начальное значение переменной (1.0) рассматривается как константа и складывается графиком и никогда не рассматривается снова в процессе обучения, даже если переменная, по-видимому, должным образом изменена (до 0,5) обратным вызовом.

dense1 = tf.keras.layers.Dense(10)
decay = tf.Variable(1.0, trainable=False)
dense2 = tf.keras.layers.Dense(10)

def epoch_callback(epoch):
    nonlocal decay
    tf.keras.backend.set_value(decay, 0.5)
    #decay.assign(0.5)
    print(tf.keras.backend.get_value(decay))

input = tf.keras.layers.Input((MAX_LENGTH,))
x = dense1(input)

with tf.control_dependencies([decay]):
    x = x * decay

prediction = dense2(x)

model = tf.keras.Model(inputs=[input], outputs=[prediction])
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

callbacks = [tf.keras.callbacks.LambdaCallback(on_epoch_begin = lambda epoch, logs: epoch_callback(epoch))]

model.fit(train_ds, epochs=EPOCHS, verbose=1, callbacks=callbacks, validation_data=eval_ds)

@ nbro: Здесь вы go. Код ниже, что работает для меня. Я использую протокол принуждения учителя, а переменная затухания для каждой эпохи используется, чтобы «понизить голос учителя» в процессе обучения.

class Teacher(tf.keras.layers.Layer):
    def __init__(self, embedding, name='teacher', **kwargs):
        super().__init__(name=name, **kwargs)
        ...

    def build(self, input_shape):
        ...

    def call(self, inputs, training=None):
        x, y, decay = inputs
        ...
        if training:
            y = tf.multiply(y, decay)
        else:
            y = tf.multiply(y, tf.constant(0.0))
        ...
        return x

    def get_config(self):
        return {}

class MyNet(tf.keras.Model):
    def __init__(self, name='mynet', **kwargs):
        super().__init__(name=name, **kwargs)

    def build(self, input_shape):
        ...
        self.teacher = Teacher()
        self.decay = tf.Variable(1.0, trainable=False)
        ...

    def set_decay(self, decay):
        self.decay.assign(decay)

    @tf.function
    def call(self, example, training=None):
        x, y = example
        ...
        x = self.teacher((x, y, self.decay))
        ...
        return x

    def get_config(self):
        return {}

def main():

    train_ds = ...
    eval_ds = ...

    train_ds = train_ds.map(lambda data, label: ((data, label), label), num_parallel_calls=tf.data.experimental.AUTOTUNE)
    eval_ds = eval_ds.map(lambda data, label: ((data, label), label), num_parallel_calls=tf.data.experimental.AUTOTUNE)


    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        the_net = MyNet()
        inputs = tf.keras.layers.Input((MAX_LENGTH,), dtype='int64', name='inputs')
        targets = tf.keras.layers.Input((MAX_LENGTH,), dtype='int64', name='targets')
        prediction = the_net((inputs, targets))
        model = tf.keras.Model(inputs=[inputs, targets], outputs=[prediction])
        model.compile(optimizer=tf.keras.optimizers.Adam(), loss=CosineSimilarity(name='val_loss'))

    def _callback_fun(epoch, start = 0, steps = 8):
        the_net.set_decay(tf.clip_by_value((start+steps-epoch)/steps, clip_value_min=tf.constant(0.0), clip_value_max=tf.constant(1.0)))

    callbacks = [tf.keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: _callback_fun(epoch))]

    model.fit(train_ds, epochs=EPOCHS, verbose=2, callbacks=callbacks, validation_data=eval_ds)

if __name__ == '__main__':
    main()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...