Для ясности рассмотрим два случая.
Случай 1: Простая модель и
Случай 2: Сложная модель, в которой использовались определяемые пользователем классы, унаследованные от tf.keras.Model
.
Случай 1 : Простая модель (как в функциональных и последовательных моделях keras)
Когда вы сохраняете веса модели (используя model.save_weights
), а затем загружаете веса (используя model.load_weights
), по умолчанию метод load_weights
использует топологическую загрузку . Это то же самое для формата Tensorflow saved_model ('tf'), а также для формата 'h5'. Например,
loadedh5_model.load_weights('./MyModel_h5.h5')
# the line above is same as the line below (as second and third arguments are default)
#loadedh5_model.load_weights('./MyModel_h5.h5',by_name=False, skip_mismatch=False)
В случае, если вы хотите загрузить веса указанных c слоев сохраненной модели, вам нужно использовать by_name=True
. Существуют варианты использования, для которых требуется этот тип загрузки.
loadedh5_model.load_weights('./MyModel_h5.h5',by_name=True, skip_mismatch=False)
Случай 2: Сложная модель (как в моделях подкласса Keras)
На данный момент поддерживается только формат tf, только если При создании модели использовались пользовательские классы, унаследованные от tf.keras.Model
.
При загрузке весов из формата TensorFlow поддерживается только топологическая загрузка (by_name = False). Обратите внимание, что топологическая загрузка форматов TensorFlow и HDF5 немного отличается для определяемых пользователем классов, унаследованных от tf.keras.Model: HDF5 загружается на основе сглаженного списка весов, в то время как формат TensorFlow загружается на основе локальных имен объектов атрибутов, для которых слои назначаются в конструкторе модели.
Основная причина в том, что веса имеют формат h5
и формат tf
. Например, рассмотрим Case 1
, где HDF5 загружается на основе уплощенного списка весов. Вес загружается без ошибок. Однако в Case 2
модель имеет user defined classes
, что требует другого подхода, чем просто загрузка плоских грузов. Чтобы позаботиться о назначении весов пользовательских классов, формат 'tf' загружает веса на основе локальных имен объектов атрибутов, которым присваиваются слои в конструкторе модели.
Следующий абзац упоминается в keras веб-сайт, дополнительно поясняет
При загрузке файла веса в формате TensorFlow возвращает тот же объект статуса, что и tf.train.Checkpoint.restore. При построении графа операции восстановления запускаются автоматически, как только сеть построена (при первом вызове для определяемых пользователем классов, наследующих от Model, сразу же, если она уже построена).
Еще один момент для понимания Модели keras Functional
или Sequential
представляют собой статические c графики слоев, которые могут без проблем использовать сплющенные веса. Модель с подклассом Keras (как в нашем случае 2) представляет собой кусок кода Python (метод вызова). Графика слоев нет. Поэтому, как только сеть построена с использованием настраиваемых классов, запускаются операции восстановления для обновления объектов состояния. Надеюсь, поможет.