Не зная деталей вашей модели, следующий фрагмент может помочь:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
# Train your initial model
def get_initial_model():
...
return model
model = get_initial_model()
model.fit(...)
model.save_weights('initial_model_weights.h5')
# Use Model API to create another model, built on your initial model
initial_model = get_initial_model()
initial_model.load_weights('initial_model_weights.h5')
nn_input = Input(...)
x = initial_model(nn_input)
x = Dense(...)(x) # This is the additional layer, connected to your initial model
nn_output = Dense(...)(x)
# Combine your model
full_model = Model(inputs=nn_input, outputs=nn_output)
# Compile and train as usual
full_model.compile(...)
full_model.fit(...)
По сути, вы тренируете свою первоначальную модель, сохраняете ее. И перезагрузите его снова, и оберните его вместе с вашими дополнительными слоями, используя Model
API. Если вы не знакомы с Model
API, вы можете проверить документацию Keras здесь (на самом деле API остается неизменным для Tensorflow.Keras 2.0).
Обратите внимание, что вам нужно проверьте, совместима ли выходная форма конечного слоя вашей исходной модели с дополнительными слоями (например, вы можете удалить конечный плотный слой из вашей исходной модели, если вы просто извлекаете объект).