Неверный вывод для восстановленной переменной на графике тензорного потока - PullRequest
0 голосов
/ 02 июля 2018

В настоящее время я играю с сохранением и восстановлением переменных. Для этого я создал два скрипта. Один из них сохраняет простой график, а другой восстанавливает его. Вот тестовый скрипт для сохранения графика:

import tensorflow as tf

a = tf.Variable(3.0, name='a')
b = tf.Variable(5.0, name='b')

b = tf.assign_add(b, a)

n_steps = 5

global_step = tf.Variable(0, name='global_step', trainable=False)

saver = tf.train.Saver()

with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    for step in range(n_steps):
        print(sess.run(b))

        global_step.assign_add(1).eval()
        print(global_step.eval())

        saver.save(sess, './my_test_model', global_step=global_step)

По сути, я хочу выполнить цикл 5 раз, и каждый раз, когда я делаю это, я добавляю a к b. Я также хочу отслеживать количество шагов с помощью global_step. Это работает как задумано. Выход:

8.0     # value of b
1       # step
11.0
2
14.0
3
17.0
4
20.0
5

Теперь при восстановлении переменных я пытаюсь получить все три из них. Сценарий:

import tensorflow as tf

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# List ALL tensors.
print_tensors_in_checkpoint_file(tf.train.latest_checkpoint('./'), all_tensors=True, tensor_name='')

tf.reset_default_graph()

a = tf.get_variable('a', shape=[])
b = tf.get_variable('b', shape=[])
global_step = tf.get_variable('global_step', shape=[])

saver = tf.train.Saver()

with tf.Session() as sess:

    ckpt = tf.train.latest_checkpoint('./')
    if ckpt:
        print(ckpt)

        saver.restore(sess, ckpt)

    else:
        print('Nothing restored')

    print(a.eval())
    print(b.eval())
    print(global_step.eval())

Вывод этого

tensor_name:  a
3.0
tensor_name:  b
20.0
tensor_name:  global_step
5
./my_test_model-5
INFO:tensorflow:Restoring parameters from ./my_test_model-5
3.0
20.0
7e-45

Как это возможно, что значение global_step правильно хранится в контрольной точке, но после оценки я получаю этот маленький 7e-45 ? Кроме того, после восстановления мне кажется, что я не могу определить какие-либо дополнительные переменные, поскольку он утверждает, что не может найти переменную в контрольной точке. Как, например, определить переменную и добавить ее к b восстановленного графа?

Спасибо за помощь!

1 Ответ

0 голосов
/ 02 июля 2018

Это не очень хорошо документировано в документах TF, но вы должны указать dtype для переменной global_step.

Некорректное

global_step = tf.get_variable('global_step', shape=[], dtype=tf.float32) результаты в global_step=7e-5. Предполагается, что типом по умолчанию является dtf.float32.

Корректное

global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32) результаты в global_step=5

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...