Вы можете использовать класс ниже, созданный из tf.estimator.BestExporter
Что он делает, за исключением сохранения лучшей модели (.pb файлы и et c) он также сохранит контрольную точку наиболее экспортируемой модели в другой папке.
Ниже приведен класс:
import shutil, glob, os
# import tensorflow.logging as logging
## the path where all the checkpoint reside
BEST_CHECKPOINTS_PATH_FROM = 'PATH TO ALL CHECKPOINT FILES'
## the path it will save the best exporter checkpoint files
BEST_CHECKPOINTS_PATH_TO = 'PATH TO BEST EXPORTER CHECKPOINT FILES TO BE SAVE'
class BestCheckpointsExporter(tf.estimator.BestExporter):
def export(self, estimator, export_path, checkpoint_path, eval_result,is_the_final_export):
if self._best_eval_result is None or \
self._compare_fn(self._best_eval_result, eval_result):
#print('Exporting a better model ({} instead of {})...'.format(eval_result, self._best_eval_result))
for name in glob.glob(checkpoint_path + '.*'):
print(name)
print(os.path.join(BEST_CHECKPOINTS_PATH_TO, os.path.basename(name)))
shutil.copy(name, os.path.join(BEST_CHECKPOINTS_PATH_TO, os.path.basename(name)))
# also save the text file used by the estimator api to find the best checkpoint
with open(os.path.join(BEST_CHECKPOINTS_PATH_TO, "checkpoint"), 'w') as f:
f.write("model_checkpoint_path: \"{}\"".format(os.path.basename(checkpoint_path)))
self._best_eval_result = eval_result
else:
print('Keeping the current best model ({} instead of {}).'.format(self._best_eval_result, eval_result))
Пример использования класса Вы просто замените экспортер, вызвав класс и передавая serve_input_receiver_fn.
def serving_input_receiver_fn():
inputs = {'my_dense_input': tf.compat.v1.placeholder(shape=[None, 4], dtype=tf.float32)}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
exporter = BestCheckpointsExporter(serving_input_receiver_fn=serving_input_receiver_fn)
train_spec_dnn = tf.estimator.TrainSpec(input_fn = input_fn, max_steps=5)
eval_spec_dnn = tf.estimator.EvalSpec(input_fn=input_fn,exporters=exporter,start_delay_secs=0,throttle_secs=15)
(x, y) = tf.estimator.train_and_evaluate(keras_estimator, train_spec_dnn, eval_spec_dnn)
На этом этапе он сохранит файлы контрольных точек наиболее экспортируемых моделей в указанной вами папке.
Для загрузки файлов контрольных точек вам необходимо выполнить следующие шаги:
Шаг 1: Перестроить экземпляр вашей модели
def build_model():
model = tf.keras.models.Sequential()
model.add(...)
model.compile(...)
return model
model = build_model()
Шаг 2: использовать модель load_weights API Ссылочный URL: https://www.tensorflow.org/tutorials/keras/save_and_load
ck_path = tf.train.latest_checkpoint('PATH TO BEST EXPORTER CHECKPOINT FILES')
model.load_weights(ck_path)
## From there you will be able to call the predict & evaluate the functionality of the trained model
##PREDICT
prediction = model.predict(x)
##EVALUATE
for features_batch, labels_batch in input_fn().take(1):
model.evaluate(features_batch, labels_batch)
Примечание. Все они были смоделированы в Google Colab.