Я использовал пользовательскую модель, доступную на сайте TF, чтобы продемонстрировать эту идею. Использование подклассовых моделей мало отличается от последовательной и функциональной модели Keras. Я использовал модель Subclassed для демонстрации идеи следующим образом.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
базовая модель
class ThreeLayerMLP(keras.Model):
def __init__(self, name=None):
super(ThreeLayerMLP, self).__init__(name=name)
self.dense_1 = layers.Dense(64, activation='relu', name='dense_1')
self.dense_2 = layers.Dense(64, activation='relu', name='dense_2')
self.pred_layer = layers.Dense(10, name='predictions')
def call(self, inputs):
x = self.dense_1(inputs)
x = self.dense_2(x)
return self.pred_layer(x)
def get_model():
return ThreeLayerMLP(name='3_layer_mlp')
base_model = get_model()
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
base_model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.RMSprop())
history = base_model.fit(x_train, y_train,
batch_size=64,
epochs=1)
#saving weights
base_model.save_weights('./base_model_weights', save_format='tf')
пользовательская модель
class MyCustomModel(keras.Model):
def __init__(self, name=None):
super(MyCustomModel, self).__init__(name=name)
self.dense_1 = layers.Dense(64, activation='relu', name='dense_1')
self.dense_2 = layers.Dense(64, activation='relu', name='dense_2')
self.pred_layer = layers.Dense(10, name='predictions')
def call(self, inputs):
x = self.dense_1(inputs)
x = self.dense_2(x)
return self.pred_layer(x)
def get_custom_model():
return MyCustomModel(name='my_custom_model')
my_custom_model = get_custom_model()
my_custom_model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.RMSprop())
Следующим шагом является импорт, если в сохраненной модели есть какие-либо пользовательские объекты или пользовательские слои в модели
# This initializes the variables used by the optimizers,
# as well as any stateful metric variables
my_custom_model.train_on_batch(x_train[:1], y_train[:1])
# Load the state of the old model (to load weights for all layers)
# my_custom_model.load_weights('path_to_my_weights')
layer_dict = dict([(layer.name, layer) for layer in base_model.layers])
print(layer_dict)
# my_custom_model.trainable = True
# loading the weights from base_model
for layer in my_custom_model.layers:
layer_name = layer.name
#print(layer.name)
layer.set_weights(layer_dict[layer_name].get_weights())