Я построил пару пользовательских слоев и модель в Tensorflow2.0a, расширяя классы Keras. Во всех них я реализовал методы get_config()
и from_config()
. Я также реализовал метод model.save()
в классе CustomModel (аналогично последовательной модели).
def save(self, filepath, overwrite=True, include_optimizer=True, **kwargs):
from tensorflow.python.keras.models import save_model # pylint: disable=g-import-not-at-top
save_model(self, filepath, overwrite, include_optimizer)
Эти методы позволяют мне использовать:
model.save("model.h5")
new_model = keras.models.load_model('model.h5',
custom_objects={
'CustomModel': CustomModel,
'CustomLayer': CustomLayer, #...
})
Для загрузки модели с 8 слоями и 9000 параметрами требуется около 30 секунд.
Но я также могу загрузить модель следующим образом:
model.save_weights("weights.h5")
config = model.get_config()
new_model = CustomModel.from_config(config)
new_model.load_weights("weights.h5")
С другой стороны, этот метод занимает 0,6 с для той же модели.
Есть ли причина, по которой первый метод такой медленный? Это связано с десериализацией пользовательской модели / слоев?
В таком случае какой метод должен быть предпочтительным?
Я знаю, что TF2 довольно новый, но любая информация будет принята с благодарностью!