Используя tf.keras в TF 2.0, как я могу определить пользовательский слой, который зависит от фазы обучения? - PullRequest
0 голосов
/ 14 февраля 2019

Я хочу создать пользовательский слой, используя tf.keras.Для простоты предположим, что он должен возвращать входные данные * 2 во время обучения и входные данные * 3 во время тестирования.Как правильно это сделать?

Я попробовал этот подход:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training:
            return inputs*2
        else:
            return inputs*3

Затем я могу использовать этот класс следующим образом:

>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)

Он отлично работает!Однако, когда я использую этот класс в модели и вызываю его метод fit(), кажется, что training не установлен в True.Я попытался добавить следующий код в начале метода call(), но training всегда равен 0.

if training is None:
    training = K.learning_phase()

Чего мне не хватает?

Edit

Я нашел решение (см. Мой ответ), но я все еще ищу более подходящее решение с использованием @tf.function (я предпочитаю автограф этому smart_cond() бизнесу).К сожалению, похоже, что K.learning_phase() не очень хорошо работает с @tf.function (я полагаю, что когда функция call() отслеживается, этап обучения жестко запрограммирован в графике: так как это происходит до вызова fit() метод, фаза обучения всегда 0).Это может быть ошибка, или, возможно, есть другой способ получить фазу обучения при использовании @tf.function.

Ответы [ 2 ]

0 голосов
/ 15 февраля 2019

Франсуа Шоле подтвердил, что правильное решение при использовании @tf.function:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        if training:
            return inputs * 2
        else:
            return inputs * 3

В настоящее время существует ошибка (по состоянию на 15 февраля 2019 г.), из-за которой training всегда равно 0, ноэто будет исправлено в ближайшее время.

0 голосов
/ 14 февраля 2019

В следующем коде не используется @tf.function, поэтому он выглядит не так хорошо (поскольку не использует автограф), но работает нормально:

from tensorflow.python.keras.utils.tf_utils import smart_cond

class CustomLayer(Layer):
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)
...