Получение градиентов каждого слоя в Керасе 2 - PullRequest
1 голос
/ 05 февраля 2020

Уже несколько дней я боролся за то, чтобы просто просмотреть градиенты слоев в режиме отладки Keras2. Излишне говорить, что я уже пробовал коды, такие как:

import Keras.backend as K
gradients = K.gradients(model.output, model.input)
sess = tf.compat.v1.keras.backend.get_session()
evaluated_gradients = sess.run(gradients, feed_dict={model.input:images})

или

evaluated_gradients = sess.run(gradients, feed_dict{model.input.experimantal_ref():images})

или

with tf.compat.v1.Session(graph=tf.compat.v1.keras.backend.get_default_graph())

или аналогичные подходы с использованием

tf.compat.v1

, что приводит к следующей ошибке:

RuntimeError: Граф сеанса пуст. Добавьте операции в график перед вызовом run ().

Я предполагаю, что это должен быть самый базовый c инструмент, который может предоставить любой пакет глубокого обучения, странно, почему нет простого способа сделать это так в керасе2. Есть идеи?

1 Ответ

1 голос
/ 06 февраля 2020

Вы можете попытаться сделать это на TF 2 с активным режимом.

Обратите внимание, что вам нужно использовать tf.keras для всего, включая вашу модель, слои и т. Д. c. Чтобы это работало, вы никогда не можете использовать keras в одиночку, это должно быть tf.keras. Это означает, например, что tf.keras.layers.Dense, tf.keras.models.Sequential, et c ..

input_images_tensor = tf.constant(input_images_numpy)
with tf.GradientTape() as g:
    g.watch(input_images_tensor)
    output_tensor = model(input_images_tensor)

gradients = g.gradient(output_tensor, input_images_tensor)

Если вы собираетесь вычислять градиенты более одного раза с одной и той же ленты, вы нужно, чтобы лента была persistent=True и удалите ее вручную после получения градиентов. (Подробности см. По ссылке ниже)

Вы можете получить градиенты относительно любого «тренируемого» веса без необходимости watch. Если вы собираетесь получать градиенты по отношению к необучаемым тензорам (таким как входные изображения), то вы должны вызвать g.watch, как указано выше для каждой из этих переменных).

Подробнее о GradientTape: https://www.tensorflow.org/api_docs/python/tf/GradientTape

...