Восстановить обученную модель Tensorflow KeyError: 'BlockLSTM' - PullRequest
0 голосов
/ 11 мая 2018

Итак, я пытаюсь загрузить свою обученную модель Tensorflow, но получаю эту странную ошибку, и я не смог найти никаких ответов относительно этой конкретной ошибки.

Это мой заставочный вызов:

with tf.Session(graph=self.graph) as sess:
    saver = tf.train.Saver()
    for i in range(self.c.epochs):
        batch_data, batch_labels = self.get_batch(train_keys, self.c.doc_len, self.c.num_classes, batch_size=self.c.batch_size)

        _, batch_loss = sess.run([self.optimizer, self.loss], feed_dict={self.input_data: batch_data, self.labels: batch_labels, self.dropout_rate: 0.5})

        if (i % 2 == 0 and i != 0 or i == self.c.epochs-1):
            saver.save(sess, save_model_file, global_step=2)

И это моя функция восстановления:

tf.reset_default_graph()
    saver = tf.train.import_meta_graph(trained_model_name)

    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint('./'))

    graph = tf.get_default_graph()

    X_init = tf.placeholder(tf.float32, shape=(c.vocab_size, c.emb_size))

    input_data = graph.get_tensor_by_name("input_data")

    preds = graph.get_tensor_by_name("preds")

    init = tf.global_variables_initializer()

    sess.run(init, feed_dict={X_init: lexvec_model})

    pred = sess.run(preds, feed_dict={input_data: model_input})

Цель состоит в том, чтобы использовать восстановленную модель, чтобы сделать выводы, но я получаю сообщение об ошибке "saver = tf.train.import_meta_graph (train_model_name)",Некоторая помощь будет отличной :)

Код ошибки:

Traceback (most recent call last):
  File "C:/Users/.../main/Predictor.py", line 94, in <module>
    prediction = predictor.predict(text_doc=doc)
  File "C:/Users/.../main/Predictor.py", line 57, in predict
    saver = tf.train.import_meta_graph(trained_model_name)
  File "C:\Users\...\Python36\lib\site-packages\tensorflow\python\training\saver.py", line 1927, in import_meta_graph **kwargs)
  File "C:\Users\...\Python\Python36\lib\site packages\tensorflow\python\framework\meta_graph.py", line 741, in import_scoped_meta_graph 
    producer_op_list=producer_op_list)
  File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func return func(*args, **kwargs)
  File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\framework\importer.py", line 457, in import_graph_def _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
  File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\framework\importer.py", line 227, in _RemoveDefaultAttrs
    op_def = op_dict[node.op]
KeyError: 'BlockLSTM'

1 Ответ

0 голосов
/ 31 октября 2018

У меня была такая же проблема при использовании LSTMBlockFusedCell ().решение в https://github.com/tensorflow/tensorflow/issues/23369

# for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369
tf.contrib.rnn
# restore meta graph
meta_file = args.restore + '.meta'
loader = tf.train.import_meta_graph(meta_file, clear_devices=True)
...
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...