Я использовал OpenAI для обучения модели DeepQ. После выполнения
saver = tf.train.Saver()
saver.save(tf.get_default_session(), 'my_deepq')
Я получил следующие файлы:
my_deepq.data-00000-of-00001
my_deepq.index
checkpoint
my_deepq.meta
Затем мне нужно загрузить эту модель в две разные системы (C ++ и python), чтобы сделать вывод.
Для части Python я попытался:
import tensorflow as tf
tf.reset_default_graph()
imported_graph = tf.train.import_meta_graph('my_deepq.meta')
with tf.Session() as sess:
imported_graph.restore(sess, './my_deepq')
Коды работали, но я не уверен, где была загружена модель и как сделать вывод. Может кто-нибудь, пожалуйста, посоветуйте.
Для стороны C ++ я сделаю что-то вроде:
tensorflow::Session *my_sess;
tensorflow::Status status = tensorflow::NewSession(options, &my_sess);
tensorflow::GraphDef graph_def;
status = ReadBinaryProto(tensorflow::Env::Default(), model_path, &graph_def);
status = my_sess->Create(graph_def);
tensorflow::Status status = my_sess->Run({{"My_Input", input_tensor}}, {"My_Output"}, {}, &output_tensor);
Этот подход требует, чтобы модель была в формате BinaryProto, но я не уверен, как сохранить мою модель в BinaryProto на python. Может кто-нибудь, пожалуйста, посоветуйте. Спасибо!