Как управлять правилом дифференциальной цепи в Керасе - PullRequest
1 голос
/ 03 марта 2020

У меня сверточная нейронная сеть с несколькими слоями в керасах. Последний слой в этой сети - это пользовательский слой, который отвечает за сортировку некоторых чисел, полученных этим слоем с предыдущего уровня, затем выходные данные пользовательского слоя отправляются для функции расчета потерь.

для этой цели (сортировка) я использую некоторые операторы в этом слое, такие как K.argmax и K.gather.

На этапе обратного распространения я получаю ошибку от keras, которая говорит:

Операция имеет None для градиента. Пожалуйста, убедитесь, что все ваши операции имеют определенный градиент (то есть являются дифференцируемыми). Обычные операции без градиента: K.argmax, K.round, K.eval

, что является разумным , вызывает участие этого слоя в процессе деривации.

Учитывая, что мой пользовательский уровень не нуждается в корпоративном правиле дифференциальной цепочки, как я могу контролировать дифференциальную цепочку в кератах? я могу отключить этот процесс в пользовательском слое?

Порядок переупорядочения, который я использовал в своем коде, просто следующий:

def Reorder(args):
    z = args[0]
    l = args[1]
    index = K.tf.argmax(l, axis=1)
    return K.tf.gather(z, index)

Reorder_Layer = Lambda(Reorder, name='out_x')
pred_x = Reorder_Layer([z, op])

1 Ответ

1 голос
/ 03 марта 2020

Несколько вещей:

  • Невозможно тренироваться без деривата, поэтому нет решения, если вы хотите тренировать эту модель
  • Нет необходимости "компилировать" если вы только собираетесь прогнозировать, поэтому вам не нужны пользовательские правила деривации

Если проблема действительно в этом слое, я предполагаю, что l вычисляется моделью с использованием обучаемых слоев до Это.

Если вы действительно хотите попробовать это, что не очень хорошая идея, вы можете попробовать l = keras.backend.stop_gradient(args[1]). Но это означает, что абсолютно ничто не будет обучено с l до начала модели. Если это не работает, то вы должны сделать все слои, которые производят l, имеют trainable=False до компиляции модели.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...