Возврат с замороженным pb-файлом на примере MNIST - PullRequest
0 голосов
/ 03 декабря 2018

У меня есть обучающий скрипт MNIST для python tenorflow, который может генерировать замороженный файл * .pb для вывода.

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.graph_util import convert_variables_to_constants
import utils

epochs = 250
batch_size = 55000 # Entire training set

# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batches = int(len(mnist.train.images) / batch_size)

# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784], name='image')
label = tf.placeholder(tf.float32, [None, 10], name='label')

# Define the model
layer1 = layers.fully_connected(image, 300)
layer2 = layers.fully_connected(layer1, 300)
logits = layers.fully_connected(layer2, 10)

# Create global step variable (needed for pruning)
global_step = tf.train.get_or_create_global_step()
reset_global_step_op = tf.assign(global_step, 0)

# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

# running this operation increments the global_step
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step, name='train_op')

# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

# Create a saver for writing training checkpoints.
saver = tf.train.Saver()

with tf.Session() as sess:
    # Uncomment the following if you don't have a trained model yet
    sess.run(tf.initialize_all_variables())

    # Train the model before pruning (optional)
    for epoch in range(epochs):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
            print("Normal Train Model step %d test accuracy %g" % (epoch, acc_print))

    acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Normal Train Model accuracy:", acc_print)

    # Save to full converted pb file
    graph = convert_variables_to_constants(sess, sess.graph_def, ["accuracy"])
    with tf.gfile.FastGFile('normal_train2.pb', mode='wb') as f:
        f.write(graph.SerializeToString())

И что я могу сделать, это использовать * .pb в качестве логического вывода, чтобы я мог оценить некоторые изображения.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import utils
import sys

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def load_model(path_to_model):
    if not os.path.exists(path_to_model):
        raise ValueError("'path_to_model.pb' is not exist.")

    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return model_graph

def main(path_to_model):    
    print("path_to_model:", path_to_model)
    model_graph = load_model(path_to_model)

    accuracy = model_graph.get_tensor_by_name('accuracy:0')
    image = model_graph.get_tensor_by_name('image:0')
    label = model_graph.get_tensor_by_name('label:0')

    equal = model_graph.get_tensor_by_name('Equal:0')
    logits = model_graph.get_tensor_by_name('ArgMax:0')

    with tf.Session(graph=model_graph) as sess:
        acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
        print("eval_from_pb accuracy:", acc_print)

if __name__ == "__main__":
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    main(*sys.argv[1:])    

Также с просмотром * .pbtxt и печатьюЯ знаю, что могу использовать logits = model_graph.get_tensor_by_name('ArgMax:0'), чтобы на самом деле напечатать метку предсказания изображения.

Я пытаюсь спросить: возможно ли сделать файл train_op из * .pb? Мне не нужно непрерывное обучение , я просто хочу знать, как восстановить train_op с файлом * .pb, чтобы я мог тренироваться с самого начала.

У меня естьпрочитайте оба Переподготовьте замороженную модель * .pb в TensorFlow и Клонирование сети с помощью tf.contrib.graph_editor , оба они не являются примерами, я пробовал с ними,но не везет.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...