не могу загрузить модель keras с неизвестной пользовательской функцией потери лямбды - PullRequest
0 голосов
/ 19 апреля 2020

моя модель выглядит как

Функции множественных потерь с одним выходом в Керасе: { ссылка }

model = Model(inputs=[sketch_inp, color_inp], outputs=disc_outputs)

opt = Adam(lr=learning_rate, beta_1=.5)



model.compile(loss=lambda y_true, y_pred : tf.keras.losses.binary_crossentropy(y_true, y_pred) + \
                                                 pixelLevelLoss_weight * pixelLevelLoss(y_true, y_pred) + \
                                                 totalVariationLoss_weight * totalVariationLoss(y_true, y_pred) + \
                                                 featureLevelLoss_weight * featureLevelLoss(y_true, y_pred),\
                    optimizer=opt)

После сохранения модели я хочу загрузить ее и завершить обучение, но я не знаю, как загрузить его с помощью этой пользовательской функции потери

1 Ответ

0 голосов
/ 19 апреля 2020

Во время загрузки модели просто используйте аргумент cutom_objects для передачи потери.

Если модель, которую вы хотите загрузить, включает в себя пользовательские слои или другие пользовательские классы или функции, вы можете передать их в механизм загрузки через Аргумент custom_objects:

from keras.models import load_model
# Assuming your model includes instance of an "AttentionLayer" class
model = load_model('my_model.h5', custom_objects={'AttentionLayer': AttentionLayer})
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...