Как правильно возобновить обучение сети из файла контрольной точки tenorflow? - PullRequest
0 голосов
/ 02 сентября 2018

Я изо всех сил пытаюсь восстановить модель за один день без какого-либо успеха. Мой код состоит из class TF_MLPRegressor(), где я определяю сетевую архитектуру в конструкторе. Затем я вызываю функцию fit() для обучения. Вот так я сохраняю простую модель Perceptron с 1 скрытым слоем в функции fit():

            starting_epoch = 0
            # Launch the graph
            tf.set_random_seed(self.random_state)   # fix the random seed before creating the Session in order to take effect!
            if hasattr(self, 'sess'):
                self.sess.close()
                del self.sess   # delete Session to release memory
                gc.collect()
            self.sess = tf.Session(config=self.config) # save the session to predict from new data
            # Create a saver object which will save all the variables
            saver = tf.train.Saver(max_to_keep=2)  # max_to_keep=2 means to not keep more than 2 checkpoint files
            self.sess.run(tf.global_variables_initializer())

# ... (each 100 epochs)

            saver.save(self.sess, self.checkpoint_dir+"/resume", global_step=epoch)

Затем я создаю новый экземпляр TF_MLPRegressor() с точно такими же значениями входных параметров и вызываю функцию fit() для восстановления модели следующим образом:

    self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
    ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
    starting_epoch = int(ckpt.split('-')[-1])
    metagraph = ".".join([ckpt, 'meta'])
    saver = tf.train.import_meta_graph(metagraph)
    self.sess.run(tf.global_variables_initializer())    # Initialize variables
    lhl = tf.trainable_variables()[2]
    lhlA = lhl.eval(session=self.sess)
    saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model
    lhlB = lhl.eval(session=self.sess)
    print lhlA == lhlB

lhlA и lhlB - последние веса скрытого слоя до и после восстановления, и, согласно моему коду, они полностью совпадают, а именно сохраненная модель не загружается в сеанс. Что я делаю не так?

1 Ответ

0 голосов
/ 03 сентября 2018

Я нашел обходной путь! Как ни странно, метаграф не содержит всех переменных, которые я определил, или присваивает им новые имена. Для примеров в конструкторе я определяю тензоры, которые будут переносить входные векторы объектов и экспериментальные значения:

self.x = tf.placeholder("float", [None, feat_num], name='x')
self.y = tf.placeholder("float", [None], name='y')

Однако, когда я выполняю tf.reset_default_graph() и загружаю метаграф, я получаю следующий список переменных:

[
<tf.Variable 'Variable:0' shape=(300, 300) dtype=float32_ref>, 
<tf.Variable 'Variable_1:0' shape=(300,) dtype=float32_ref>, 
<tf.Variable 'Variable_2:0' shape=(300, 1) dtype=float32_ref>, 
<tf.Variable 'Variable_3:0' shape=(1,) dtype=float32_ref>
]

Для записи каждый вектор входных объектов имеет 300 объектов. Во всяком случае, когда я позже попытаюсь начать обучение, используя:

_, c, p = self.sess.run([self.optimizer, self.cost, self.pred], 
feed_dict={self.x: batch_x, self.y: batch_y, self.isTrain: True})

Я получаю сообщение об ошибке, подобное:

"TypeError: Cannot interpret feed_dict key as Tensor: Tensor 'x' is not an element of this graph."

Итак, поскольку каждый раз, когда я создаю экземпляр class TF_MLPRegressor(), я определяю сетевую архитектуру в конструкторе, я решил не загружать метаграф, и это сработало! Я не знаю, почему TF не сохраняет все переменные в мета-графике, возможно, потому что я явно определяю сетевую архитектуру (я не использую обертки или слои по умолчанию), как в примере ниже:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

Подводя итог, я сохраняю свои модели, как описано в моем 1-м сообщении, но для их восстановления я использую это:

saver = tf.train.Saver(max_to_keep=2)
self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
self.sess.run(tf.global_variables_initializer())
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...