Как реализовать обратный градиент слоя в TF 2.0? - PullRequest
2 голосов
/ 01 июля 2019

Этот слой статический, это псевдо-функция. В прямом распространении он ничего не делает (функция идентичности). Однако при обратном распространении градиент умножается на -1. Существует множество реализаций на github, но они не работают с TF 2.0.

Вот один для справки.

import tensorflow as tf
from tensorflow.python.framework import ops

class FlipGradientBuilder(object):
    def __init__(self):
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        grad_name = "FlipGradient%d" % self.num_calls
        @ops.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            return [tf.negative(grad) * l]

        g = tf.get_default_graph()
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)

        self.num_calls += 1
        return y

flip_gradient = FlipGradientBuilder()

1 Ответ

2 голосов
/ 02 июля 2019

Dummy op, который меняет градиенты

Это можно сделать с помощью декоратора tf.custom_gradient, как описано в этом примере :

@tf.custom_gradient
def grad_reverse(x):
    y = tf.identity(x)
    def custom_grad(dy):
        return -dy
    return result, custom_grad

Затем выможно просто использовать его, как если бы он был обычным оператором TensorFlow, например:

z = encoder(x)
r = grad_reverse(z)
y = decoder(r)

Keras API?

Большим удобством TF 2.0 является его собственная поддержка Keras API.Вы можете определить пользовательский GradReverse op и наслаждаться удобством Keras:

class GradReverse(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, x):
        return grad_reverse(x)

Затем вы можете использовать этот слой как любые другие слои Keras, например:

model = Sequential()
conv = tf.keras.layers.Conv2D(...)(inp)
cust = CustomLayer()(conv)
flat = tf.keras.layers.Flatten()(cust)
fc = tf.keras.layers.Dense(num_classes)(flat)

model = tf.keras.models.Model(inputs=[inp], outputs=[fc])
model.compile(loss=..., optimizer=...)
model.fit(...)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...