Передача вывода промежуточного слоя в функцию лямбда в керасе / тензорном потоке - PullRequest
0 голосов
/ 08 октября 2018

Моя проблема заключается в следующем: у меня есть модель Keras с функцией потерь, реализованной как лямбда-слой.В этой функции потерь я хочу использовать выходные данные промежуточного слоя модели.То, что я делал, выглядит следующим образом:

.....model_init definition.....
layer_output = model_init.get_layer('conv2d_5').output
model_loss = Lambda(loss_function, output_shape=(1,), name='loss_function',
        arguments={'layer': layer_output})(
        [*model_init.output, *y_true])
model = Model([model_init.input, *y_true], model_loss)

def loss_function(args, layer):
    #do stuff with layer

    return loss

Здесь и model_init.output, и y_true являются списками из 3 слоев.Я заметил проблему при использовании нескольких графических процессоров, когда слой передается как целый пакет, а не как отдельные образцы для конкретного графического процессора.Поэтому я изменил свой код следующим образом:

.....model_init definition.....
layer_output = model_init.get_layer('conv2d_5').output
model_loss = Lambda(loss_function, output_shape=(1,), name='loss_function',
        arguments={'layer': layer_output})(
        [*model_init.output, *y_true, *[layer_output, layer_output, layer_output]])
model = Model([model_init.input, *y_true], model_loss)

def loss_function(args, layer):
    intermediate_layer = args[6]
    intermediate_layer = tf.Print(intermediate_layer, [intermediate_layer, layer], message="comparing layer values")
    #do stuff with intermediate_layer
    return loss

Это помогло решить проблему, связанную с обработкой нескольких графических процессоров.Однако теперь, когда я смотрю на значения промежуточного уровня, когда они передаются в функцию потерь, я замечаю, что они не совпадают, когда я передаю его в качестве аргумента функции (как в первом примере кода) и когда янапрямую подключите его к слою лямбда (как во втором примере).Поэтому мне было интересно, что здесь происходит и какой путь должен быть правильным.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...