Восстановление различий в моделях Tensorflow Model - PullRequest
0 голосов
/ 06 февраля 2019

Я видел и пробовал два метода, но не мог понять, в чем разница.Вот два метода, которые я использовал:
Метод 1:

saver = tf.train.import_meta_graph(tf.train.latest_checkpoint(model_path)+".meta")

sess  = tf.Session()
sess.run(tf.global_variables_initializer())   
sess.run(tf.local_variables_initializer()) 
if(tf.train.checkpoint_exists(tf.train.latest_checkpoint(model_path))):
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print(tf.train.latest_checkpoint(model_path) + "Session Loaded for Testing")    

Метод 2:

saver = tf.train.Saver()
sess  =tf.Session()
sess.run(tf.global_variables_initializer())    
if(tf.train.checkpoint_exists(tf.train.latest_checkpoint(model_path))):
        saver.restore(sess, tf.train.latest_checkpoint(model_path))
        print(tf.train.latest_checkpoint(model_path) + "Session Loaded for Testing")    

Что я хочузнать:

В чем разница между двумя вышеуказанными методами?
Какой метод загрузки модели лучше?

Пожалуйста, дайте мне знать, что вы предлагаете по этому поводу.

1 Ответ

0 голосов
/ 06 февраля 2019

Я постараюсь быть максимально кратким, поэтому вот мои 2 цента по этому вопросу.Я прокомментирую важные строки вашего кода, чтобы указать на то, что я думаю.

# Importing the meta graph is same as building the same graph from scratch
# creating the same variables, creating the same placeholders and ect.
# Basically you are only importing the graph definition
saver = tf.train.import_meta_graph(tf.train.latest_checkpoint(model_path)+".meta")

sess  = tf.Session()
# Absolutely no need to initialize the variables here. They will be initialized
# when the you restore the learned variables.
sess.run(tf.global_variables_initializer())   
sess.run(tf.local_variables_initializer()) 
if(tf.train.checkpoint_exists(tf.train.latest_checkpoint(model_path))):
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print(tf.train.latest_checkpoint(model_path) + "Session Loaded for Testing")

Что касается второго метода:

# You can't create a saver object like this, you will get an error "No variables to save", which is true.
# You haven't created any variables. The workaround for doing this is:
# saver = tf.train.Saver(defer_build=True) and then after building the graph
# ....Graph building code goes here....
# saver.build()
saver = tf.train.Saver()
sess = tf.Session()
# Absolutely no need to initialize the variables here. They will be initialized
# when the you restore the learned variables. 
sess.run(tf.global_variables_initializer())    
if(tf.train.checkpoint_exists(tf.train.latest_checkpoint(model_path))):
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print(tf.train.latest_checkpoint(model_path) + "Session Loaded for Testing")

Так что нет ничего плохого в первом подходе, кромевторой - совершенно неправильный.Не поймите меня неправильно, но мне не нравится ни один из них.Впрочем, это всего лишь личный вкус.С другой стороны, я хочу сделать следующее:

# Have a class that creates the model and instantiate an object of that class
my_trained_model = MyModel()
# This is basically the same as what you are doing with
# saver = tf.train.import_meta_graph(tf.train.latest_checkpoint(model_path)+".meta")
# Then, once I have the graph build, I will create a saver object
saver = tf.train.Saver()
# Then I will create a session
with tf.Session() as sess:
    # Restore the trained variables here
    saver.restore(sess, model_checkpoint_path)
    # Now I can do whatever I want with the my_trained_model object

Я надеюсь, что это будет полезно для вас.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...