информация о градиенте журнала keras для tenorboard - PullRequest
0 голосов
/ 20 июня 2020

Я пытаюсь построить гистограмму градиента на тензорной доске, используя тензорный поток и керас.

Я понял, что параметр with_grads для keras.callbacks.tensorboard устарел, и кажется, что для построения графика нужно использовать собственный обратный вызов градиент. Наткнулся на этот пост . Из этого ответа кажется, что необходимые строки кода:

with tf.GradientTape() as tape:
    loss = model(model.trainable_weights) 
tape.gradient(loss, model.trainable_weights)

(может быть, последняя строка тоже должна быть с отступом?), Но ответ на самом деле не указывает, где разместить эти строки, и если обратный вызов действительно необходимо: следует ли создать подкласс от tf.keras.callbacks.TensorBoard и поместить эти строки в функцию on_epoch_end? Что-то вроде

class GradientCalcCallback(keras.callbacks.TensorBoard):
    def __init__(self, model_logs_directory, histogram_freq=1, write_graph=True, write_images=True):
        super(GradientCalcCallback, self).
            __init__(log_dir=model_logs_directory, histogram_freq=histogram_freq,
                     write_graph=True, write_images=write_images, update_freq='epoch')

     def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs)
        with tf.GradientTape() as tape:
            loss = self.model(self.model.trainable_weights) 
        tape.gradient(loss, self.model.trainable_weights)

или я могу вызвать их вне обратного вызова, например, перед вызовом model.fit(...)?

Кроме того, при вызове этих строк я получаю сообщение об ошибке размеров тензоры, вызванные строкой loss = model(model.trainable_weights):

Форма тензор (784, 784) несовместима с предоставленной формой (784, 1)

(784, потому что я Я обучаю NN на изображениях MNIST, и изображения имеют размер 28x28 = 784 пикселей

Может ли кто-нибудь помочь разобраться с необходимыми шагами для построения информации о градиенте на тензорной доске при использовании API Keras?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...