Я создал базовый чатбот, используя модель Seq2Seq. Бот отлично работает, когда я последовательно запускаю код в своей записной книжке, т.е. строим модель -> обучаем модель -> тестируем модель.
Теперь я хочу сохранить модель после тренировки, загрузить модель и затем протестировать модель.
Однако у меня есть проблемы / я не могу продолжить.
Это то, что я получил до сих пор:
Сохранить модель
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, 'model_final.ckpt')
Кажется, это работает нормально
Загрузить модель
saver = tf.train.import_meta_graph("model_final.ckpt.meta")
graph = tf.get_default_graph()
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
saver.restore(sess, "model_final.ckpt")
Кажется, это работает нормально
Когда я запускаю последовательно, приведенный ниже код выполняет работу по вводу входного вопроса, его токенизации и ответу на вопрос.
prediction_c = tf.argmax(model_c, 2)
result_c = sess_c.run(prediction_c,
feed_dict={enc_input_c: input_batch_c,
dec_input_c: output_batch_c,
targets_c: target_batch_c})
После того, как я загрузил модель Seq2Seq, я не уверен, как переменные, такие как model_c, input_c, получают значения / инициализируются.
Я прошу прощения за основную природу вопроса или если то, что я пытаюсь достичь, не имеет смысла; Я только начинаю тензоры.