восстановить тензорную модель потока и запустить ее с вводом - PullRequest
1 голос
/ 04 июля 2019

У меня есть модель, которая принимает int входные данные x и создает среднее значение и дисперсию вектора размера x. Я могу сохранить эту модель, но хочу восстановить, запустив ее, передав значение x. Я также могу восстановить, но не знаю, как выполнить его после строки

saver.restore(sess, './mean_var.ckpt')

Для разных х. Могу ли я использовать feed_dict для этого? Пожалуйста, помогите мне исправить это.

import tensorflow as tf
def mean_var(x):
    vec = tf.random_normal([x])
    mean, variance = tf.nn.moments(vec, [0], keep_dims=True)
    return  mean, variance 
with tf.Graph().as_default():
    x = tf.placeholder(tf.int32)
    output = mean_var(x)
    init = tf.initialize_all_variables()
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()


    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        #val = sess.run(output, feed_dict={x: 4})
        #print(val[0], val[1])
        save_path = saver.save(sess, "./mean_var.ckpt")

tf.reset_default_graph()

with tf.Graph().as_default():
    init = tf.initialize_all_variables()
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        saver.restore(sess, './mean_var.ckpt')

1 Ответ

0 голосов
/ 04 июля 2019

Используйте это, чтобы восстановить и предсказать:

with tf.Graph().as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./mean_var.ckpt.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")   
        output = mean_var(x)
        y_pred = sess.run(output, feed_dict={x:4})
        print(y_pred)

И еще одна вещь дает имя заполнителю x, как показано ниже:

x = tf.placeholder(tf.int32, name="x")

Полный код:

import tensorflow as tf
def mean_var(x):
    vec = tf.random_normal([x])
    mean, variance = tf.nn.moments(vec, [0], keep_dims=True)
    return  mean, variance 

with tf.Graph().as_default():
    x = tf.placeholder(tf.int32, name="x")
    output = mean_var(x)
    init = tf.initialize_all_variables()
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()


    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        val = sess.run(output, feed_dict={x: 4})
        print(val[0], val[1])
        save_path = saver.save(sess, "./mean_var/mean_var.ckpt")

tf.reset_default_graph()

with tf.Graph().as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./mean_var/mean_var.ckpt.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./mean_var/'))
        #saver.restore(sess, './mean_var/mean_var.ckpt')
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")   
        output = mean_var(x)
        y_pred = sess.run(output, feed_dict={x:4})
        print(y_pred)
...