Загрузка предварительно обученной модели из save_model.pb для трансферного обучения с использованием Custom Estimator - PullRequest
0 голосов
/ 20 июня 2019

Мне было поручено перенести существующую модель из «сырого» тензорного потока.

т)

sess = tf.Session()

sess.run(tf.global_variables_initializer())
train_op = ...
accuracy = ...
for i in range(100):
        print('EPOCH', i)
        _, acc = sess.run([train_op, accuracy], feed_dict={x: inputs, y: labels})

для пользовательского оценщика (на основе общего шаблона для пользовательских оценщиков, который я разработал).

Эта существующая модель использует предварительно обученную модель в сохраненную_модель для своих начальных слоев для извлечения объектов и добавляет несколько слоев + softmax для классификации.

Общая идея:

def create_estimator(config, hyper_params):
    def _model_fn(features, labels, mode):

         <feed features into pretrained model>
         features_layer = <get final layer from pretrained model>

         logits = tf.train.Dense.apply(NUM_CLASSES,
                              activation=None,use_bias=True).
                              apply(feature_layer)
         predictions = tf.nn.softmax(predictions)
         if(mode == Modes.TRAIN or mode == Modes.EVAL):
             loss = ...
             if(Modes == Modes.TRAIN):
                 train_op = 
                    tf.train.AdamOptimizer(
                    learning_rate=learning_rate).minimize(loss)
         if(mode == Modes.PREDICT):
             ...
         return tf.estimator.EstimatorSpec(
             mode=mode,loss=cost,
             train_op=optimizer,
             training_hooks=[...])

Моей первоначальной мыслью было написать класс, производный от tf.train.SessionRunHook, так как я использовал это раньше для частичного восстановления переменных из контрольной точки, и я включил несколько начальных кодов ниже ... но в этот момент я застрял на том, как достичь своей цели.

class RestoreFromSavedModel(tf.train.SessionRunHook):
        def begin():
            graph = tf.get_default_graph()
            with tf.gfile.GFile(os.path.join(
                    MODEL_DIR, IMAGE_CLASSIFIER_FILE), 'rb') as f:
                graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
                _ = tf.import_graph_def(graph_def, name='')
        output_tensor = graph.get_tensor_by_name('prelogits')

Может ли кто-нибудь дать отзыв о том, как я могу прочитать эту сохраненную модель в свой график, передать входные данные в модель и извлечь тензор элемента для ввода в небольшую модель.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...