получение вывода всех образцов из слоя в пользовательской функции потерь - PullRequest
0 голосов
/ 11 марта 2020

У меня есть модель, которая уже обучена. Я хочу обучить предварительно обученную модель пользовательской функции потерь.

Функция потерь реализует изменение MMD и требует, чтобы все входные выборки оценивались во время обучения поверх существующей функции потерь (MSE). Я попытался написать процедуру для достижения этой цели:

model = load_model("pretrain_model.h5")
intermediate_layer_model = Model(inputs = model.layers[0].input, outputs = model.layers[0].output)

def mse_mme(layer_model, data_src_dom, data_src_trg):
    # Create a loss function that adds the MMD loss to the MSE loss
    def loss(y_true,y_pred):
        mse = K.mean(K.square(y_pred - y_true), axis=-1)

        # is this the right way of using intermediate model?
        # does intermediate model update with model?
        avg_sum_src = K.mean(intermediate_layer_model(data_src_dom), axis=0)
        avg_sum_trg = K.mean(intermediate_layer_model(data_src_trg), axis=0)
        diff = avg_sum_src - avg_sum_trg
        diff_exp = K.expand_dims(diff,0)
        mme = K.dot(diff_exp, K.transpose(diff_exp))

        sum_loss = mse + mme

        return sum_loss

    # Return a function
    return loss

def weight_reg(weight_matrix):
    return K.sum( K.exp( K.abs( weight_matrix*weight_matrix ) ) )

model.layers[0].kernel_regularizer = weight_reg

model.compile(loss=mse_mme(intermediate_layer_model, K.constant(source_data), K.constant(target_data)),
              optimizer='adam')

history = model.fit(X_train, X_train,
                    epochs=1000,
                    batch_size=64,
                    shuffle=True,
                    validation_data=(X_vali, X_vali))

Я не совсем уверен, что этот код правильный. У меня есть две проблемы:

  1. Я подозреваю, что intermediate_layer_model все еще привязан к предварительно обученной модели. Я хочу убедиться, что intermediate_layer_model обновляется каждый раз, когда model обновляется во время тренировки. Как я могу убедиться в этом?
  2. Я не уверен, правильно ли я применяю потери MMD в функции потерь. Формула, которую я пытаюсь реализовать: MMD Loss
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...