TensorFlow: восстановленные переменные кажутся случайными - PullRequest
0 голосов
/ 25 апреля 2018

У меня проблема с восстановлением некоторых переменных.Я уже восстановил переменные, когда сохранил всю модель на более высоком уровне, но на этот раз я решил восстановить только несколько переменных.Перед первым сеансом я инициализирую веса:

weights = {
'1': tf.Variable(tf.random_normal([n_input, n_hidden_1], mean=0, stddev=tf.sqrt(2*1.67/(n_input+n_hidden_1))), name='w1')
}
weights_saver = tf.train.Saver(var_list=weights)

Затем в сеансе, пока я тренирую NN:

with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
[...]
weights_saver.save(sess, './savedModels/Weights/weights')

Затем:

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(pathsToVariables + 'Weights/weights.meta')
    new_saver.restore(sess, pathsToVariables + 'Weights/weights')

    weights = 
    {
    '1': tf.Variable(sess.graph.get_tensor_by_name("w1:0"), name='w1', trainable=False)
    }

    sess.run(tf.global_variables_initializer())
    print(sess.run(weights['1']))

Но на этом этапе восстановленные веса кажутся случайными.И действительно, если я сделаю sess.run(tf.global_variables_initializer()) снова, вес будет другим.Как будто я восстановил нормальную функцию инициализации весов, но не обученных весов.

Что я делаю не так?

Понятна ли моя проблема?

1 Ответ

0 голосов
/ 02 мая 2018
 weights = 
    {
    '1': tf.Variable(sess.run(sess.graph.get_tensor_by_name("w1:0")), name='w1', trainable=False)
    }

Я узнал ответ. Мне нужно было запустить тензор, чтобы получить значения. Теперь это кажется очевидным.

изменить 2:

Этот способ не является хорошим способом инициализации тензоров из других значений, потому что он создаст 2 тензора с одинаковым именем при восстановлении, а затем создаст тензор. Или, если разные имена, он восстановит переменную из прошлой модели и может попытаться оптимизировать ее позже. Лучше восстановить переменную в предыдущем сеансе, сохранить значения, затем закрыть сеанс, открыть новый, чтобы создать все остальное.

 with tf.session() as sess: 
    weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0"))

 tf.reset_default_graph() #this will eliminate the variables we restored

 with tf.session() as sess:
    weights = 
       {
       '1': tf.Variable(weight1 , name='w1-bis', trainable=False)
       }
...

Теперь мы уверены, что восстановленные переменные не являются частью графика.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...