Как получить градиенты в Керасе? - PullRequest
0 голосов
/ 28 сентября 2019

Я хочу вычислить градиенты в каждом слое в Keras, и я реализовал решение, описанное здесь , используя версию keras TensorFlow 2.0, которая на данный момент составляет 2.2.4-tf.

import numpy as np
import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras import backend as K

model = keras.Sequential()
model.add(keras.layers.Dense(20, input_shape = (10, )))
model.add(keras.layers.Dense(5))
model.compile('adam', 'mse')

dummy_in = np.ones((4, 10))
dummy_out = np.ones((4, 5))
dummy_loss = model.train_on_batch(dummy_in, dummy_out)

def get_weight_grad(model, inputs, outputs):
    """ Gets gradient of model for given inputs and outputs for all weights"""
    grads = model.optimizer.get_gradients(model.total_loss, model.trainable_weights)
    symb_inputs = (model._feed_inputs + model._feed_targets + model._feed_sample_weights)
    f = K.function(symb_inputs, grads)
    x, y, sample_weight = model._standardize_user_data(inputs, outputs)
    output_grad = f(x + y + sample_weight)
    return output_grad


def get_layer_output_grad(model, inputs, outputs, layer=-1):
    """ Gets gradient a layer output for given inputs and outputs"""
    grads = model.optimizer.get_gradients(model.total_loss, model.layers[layer].output)
    symb_inputs = (model._feed_inputs + model._feed_targets + model._feed_sample_weights)
    f = K.function(symb_inputs, grads)
    x, y, sample_weight = model._standardize_user_data(inputs, outputs)
    output_grad = f(x + y + sample_weight)
    return output_grad


weight_grads = get_weight_grad(model, dummy_in, dummy_out)
output_grad = get_layer_output_grad(model, dummy_in, dummy_out)

Проблема в том, что функция model._standardize_user_data(dummy_in, dummy_out) возвращает NoneType для sample_weight, поэтому я не могу конкатенировать ее с x и y в f(x + y + sample_weight).Я получаю TypeError: can only concatenate list (not "NoneType") to list.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...