Реализовать слой обратного градиента в Keras с бэкендом MX- Net? - PullRequest
0 голосов
/ 11 июля 2020

Я пытаюсь реализовать адаптацию домена с помощью состязательного обучения в Keras. Я действительно заставил его работать с помощью бэкэнда TensorFlow, но, к сожалению, мне нужно реализовать его с помощью бэкэнда MX- net, чтобы быть совместимым с моими коллегами. Ключевой частью является определение слоя обращения градиента, который передает входные данные в качестве идентификатора для прямого прохода и умножает градиенты на отрицательный скаляр ( hp_lambda ) на обратном проходе. Для версии TensorFlow я использовал то, что нашел: https://github.com/michetonu/gradient_reversal_keras_tf/blob/master/flipGradientTF.py, который я опубликовал ниже:

import tensor flow as tf
from keras.engine import Layer
import keras.backend as K

def reverse_gradient(X, hp_lambda):

    @tf.RegisterGradient(grad_name)
    def _flip_gradients(op, grad):
        return [tf.negative(grad) * hp_lambda]

    g = K.get_session().graph
    with g.gradient_override_map({'Identity': grad_name}):
        y = tf.identity(X)

    return y


class GradientReversal(Layer):
    def __init__(self, hp_lambda, **kwargs):
        super(GradientReversal, self).__init__(**kwargs)
        self.supports_masking = False    
        self.hp_lambda = K.variable(hp_lambda)

    def build(self, input_shape):
        self.trainable_weights = []

    def call(self, x, mask=None):
        return reverse_gradient(x, self.hp_lambda)

    def get_output_shape_for(self, input_shape):
        return input_shape

    def get_config(self):
        config = {}
        base_config = super(GradientReversal, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

Я так понимаю, что могу оставить класс GradientReversal в основном неизменным и переопределить функцию reverse_gradient () для MX- Net, а не для Tensorflow, но мне не удалось найти MX- Net, эквивалентную чему-то вроде tf.RegisterGradient .

Кто-нибудь реализовал что-то подобное или иным образом знает, как обрабатывать настраиваемое поведение градиента в Keras с использованием бэкэнда MX- Net?

Спасибо!

...