Как преобразовать .ckpt в .pb? - PullRequest
1 голос
/ 26 июня 2019

Я новичок в области глубокого обучения и хочу использовать модель с предварительным обучением (EAST) для обслуживания на платформе AI. У меня есть эти файлы, предоставленные разработчиком:

  1. model.ckpt-49491.data-00000-оф-00001
  2. контрольно-пропускной пункт
  3. model.ckpt-49491.index
  4. model.ckpt-49491.meta

Я хочу преобразовать его в формат TensorFlow .pb. Есть ли способ сделать это? Я взял модель от здесь

Полный код доступен здесь .

Я посмотрел здесь , и он показывает следующий код для его преобразования:

С tensorflow/models/research/

INPUT_TYPE=image_tensor
PIPELINE_CONFIG_PATH={path to pipeline config file}
TRAINED_CKPT_PREFIX={path to model.ckpt}
EXPORT_DIR={path to folder that will be used for export}

python object_detection/export_inference_graph.py \
    --input_type=${INPUT_TYPE} \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --trained_checkpoint_prefix=${TRAINED_CKPT_PREFIX} \
    --output_directory=${EXPORT_DIR}

Я не могу понять, какое значение передать:

  • input_type
  • PIPELINE_CONFIG_PATH.

Ответы [ 2 ]

0 голосов
/ 22 июля 2019

Следуя ответу @Puneith Kaul, вот синтаксис для tenorflow версии 1.7:

import os
import tensorflow as tf

export_dir = 'export_dir' 
trained_checkpoint_prefix = 'models/model.ckpt'
graph = tf.Graph()
loader = tf.train.import_meta_graph(trained_checkpoint_prefix + ".meta" )
sess = tf.Session()
loader.restore(sess,trained_checkpoint_prefix)
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING, tf.saved_model.tag_constants.SERVING], strip_default_attrs=True)
builder.save()
0 голосов
/ 27 июня 2019

Вот код для преобразования контрольной точки в SavedModel

import os
import tensorflow as tf

trained_checkpoint_prefix = 'models/model.ckpt-49491'
export_dir = os.path.join('export_dir', '0')

graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:
    # Restore from checkpoint
    loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)

    # Export checkpoint to SavedModel
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess,
                                         [tf.saved_model.TRAINING, tf.saved_model.SERVING],
                                         strip_default_attrs=True)
    builder.save()                
...