Сохраните, прочитайте и выполните вывод в C ++ / Python для обученной модели OpenAI - PullRequest
1 голос
/ 11 мая 2019

Я использовал OpenAI для обучения модели DeepQ. После выполнения

saver = tf.train.Saver()
saver.save(tf.get_default_session(), 'my_deepq')

Я получил следующие файлы:

my_deepq.data-00000-of-00001
my_deepq.index
checkpoint
my_deepq.meta

Затем мне нужно загрузить эту модель в две разные системы (C ++ и python), чтобы сделать вывод.

Для части Python я попытался:

import tensorflow as tf
tf.reset_default_graph()
imported_graph = tf.train.import_meta_graph('my_deepq.meta')
with tf.Session() as sess:
    imported_graph.restore(sess, './my_deepq')

Коды работали, но я не уверен, где была загружена модель и как сделать вывод. Может кто-нибудь, пожалуйста, посоветуйте.


Для стороны C ++ я сделаю что-то вроде:

tensorflow::Session *my_sess;
tensorflow::Status status = tensorflow::NewSession(options, &my_sess);
tensorflow::GraphDef graph_def;
status = ReadBinaryProto(tensorflow::Env::Default(), model_path, &graph_def);

status = my_sess->Create(graph_def);
tensorflow::Status status = my_sess->Run({{"My_Input", input_tensor}}, {"My_Output"}, {}, &output_tensor);

Этот подход требует, чтобы модель была в формате BinaryProto, но я не уверен, как сохранить мою модель в BinaryProto на python. Может кто-нибудь, пожалуйста, посоветуйте. Спасибо!

...