как правильно загрузить модель с кастомным слоем? - PullRequest
0 голосов
/ 17 июня 2020

Я пытаюсь загрузить модель с пользовательским слоем:

siamese_model = load_model(path, custom_objects={'siamese_loss': SIAMESE_LOSS})

С переданным словарем модель должна быть успешно загружена, но ошибка все равно выскочила:

ValueError: Unknown layer: SIAMESE_LOSS

Код для настраиваемого слоя:

class SIAMESE_LOSS(Layer):
    def __init__(self, **kwargs):
        super(SIAMESE_LOSS, self).__init__(**kwargs)

    @staticmethod
    def mmd_loss(source_samples, target_samples):
        return mmd(source_samples, target_samples)

    @staticmethod
    def regression_loss(pred, labels):
        return K.mean(mae(pred, labels))

    @staticmethod
    def regression_mse(pred, labels):
        return K.mean(mse(pred, labels))

    def call(self, inputs, **kwargs):
        source_labels = inputs[0]
        target_labels = inputs[1]
        source_pred = inputs[2]
        target_pred = inputs[3]
        source_samples = inputs[4]
        target_samples = inputs[5]

        source_loss = self.regression_loss(source_pred, source_labels)
        target_loss = self.regression_loss(target_pred, target_labels)
        mmd_loss = self.mmd_loss(source_samples, target_samples)
        total_loss = source_loss + target_loss + mmd_loss

        source_mse = self.regression_mse(source_pred, source_labels)
        target_mse = self.regression_mse(target_pred, target_labels)

        self.add_loss(total_loss, inputs=True)
        self.add_metric(target_loss, aggregation='mean', name='target_mae')
        self.add_metric(source_loss, aggregation='mean', name='source_mae')
        self.add_metric(mmd_loss, aggregation='mean', name='MMD')
        self.add_metric(target_mse, aggregation='mean', name='target_mse')
        self.add_metric(source_mse, aggregation='mean', name='source_mse')
        return inputs[2], inputs[3]

    def get_config(self, **kwargs):
        super(SIAMESE_LOSS, self).get_config(**kwargs)

Есть действительно важная вещь, что я не переписал метод get_config() при обучении модели. Это причина моей проблемы?

...