Я пытаюсь реализовать адаптацию домена с помощью состязательного обучения в Keras. Я действительно заставил его работать с помощью бэкэнда TensorFlow, но, к сожалению, мне нужно реализовать его с помощью бэкэнда MX- net, чтобы быть совместимым с моими коллегами. Ключевой частью является определение слоя обращения градиента, который передает входные данные в качестве идентификатора для прямого прохода и умножает градиенты на отрицательный скаляр ( hp_lambda ) на обратном проходе. Для версии TensorFlow я использовал то, что нашел: https://github.com/michetonu/gradient_reversal_keras_tf/blob/master/flipGradientTF.py, который я опубликовал ниже:
import tensor flow as tf
from keras.engine import Layer
import keras.backend as K
def reverse_gradient(X, hp_lambda):
@tf.RegisterGradient(grad_name)
def _flip_gradients(op, grad):
return [tf.negative(grad) * hp_lambda]
g = K.get_session().graph
with g.gradient_override_map({'Identity': grad_name}):
y = tf.identity(X)
return y
class GradientReversal(Layer):
def __init__(self, hp_lambda, **kwargs):
super(GradientReversal, self).__init__(**kwargs)
self.supports_masking = False
self.hp_lambda = K.variable(hp_lambda)
def build(self, input_shape):
self.trainable_weights = []
def call(self, x, mask=None):
return reverse_gradient(x, self.hp_lambda)
def get_output_shape_for(self, input_shape):
return input_shape
def get_config(self):
config = {}
base_config = super(GradientReversal, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Я так понимаю, что могу оставить класс GradientReversal в основном неизменным и переопределить функцию reverse_gradient () для MX- Net, а не для Tensorflow, но мне не удалось найти MX- Net, эквивалентную чему-то вроде tf.RegisterGradient .
Кто-нибудь реализовал что-то подобное или иным образом знает, как обрабатывать настраиваемое поведение градиента в Keras с использованием бэкэнда MX- Net?
Спасибо!