Как сделать прогнозирование с использованием обученной и сохраненной модели тензорного потока - PullRequest
0 голосов
/ 31 января 2019

У меня есть существующая обученная модель (в частности, tenorflow word2vec https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/5_word2vec.ipynb). Я достаточно хорошо восстанавливаю существующую модель:

model1 = tf.train.import_meta_graph("models/model.meta")
model1.restore(sess, tf.train.latest_checkpoint("model/"))

Но я не знаю, как использовать только что загруженную (иобученная) модель для прогнозирования. Как делать прогнозы с восстановленной моделью?

Редактировать:

код модели из официального репозитория тензорного потока https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/word2vec/word2vec_basic.py

1 Ответ

0 голосов
/ 31 января 2019

Исходя из того, как вы загружаете контрольную точку, я предполагаю, что это должен быть лучший способ использовать ее для вывода.

Загрузка заполнителей:

input = tf.get_default_graph().get_tensor_by_name("Placeholders/placeholder_name:0")
....

Загрузка операции, которую вы используете длявыполнить прогноз:

prediction = tf.get_default_graph().get_tensor_by_name("SomewhereInsideGraph/prediction_op_name:0")

Создать сеанс, выполнить операцию прогнозирования и передать данные в заполнители.

sess = tf.Session()
sess.run(prediction, feed_dict={input:input_data})

С другой стороны, я предпочитаю всегда создаватьиметь полное создание модели внутри конструктора класса.Затем, я бы сделал следующее:

tf.reset_default_graph()
model = ModelClass()
loader = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
loader.restore(sess, path_to_checkpoint_dir)

Поскольку вы хотите загрузить вложения из обученной модели word2vec в другую модель, вы должны сделать что-то вроде:

embeddings_new_model = tf.Variable(...,name="embeddings")
embedding_saver = tf.train.Saver({"embeddings_word2vec": embeddings_new_model})
with tf.Session() as sess:
    embedding_saver.restore(sess, "word2vec_model_path")

Предполагая, что переменная embeddings в модели word2vec называется embeddings_word2vec.

...