Итак, я пытаюсь загрузить свою обученную модель Tensorflow, но получаю эту странную ошибку, и я не смог найти никаких ответов относительно этой конкретной ошибки.
Это мой заставочный вызов:
with tf.Session(graph=self.graph) as sess:
saver = tf.train.Saver()
for i in range(self.c.epochs):
batch_data, batch_labels = self.get_batch(train_keys, self.c.doc_len, self.c.num_classes, batch_size=self.c.batch_size)
_, batch_loss = sess.run([self.optimizer, self.loss], feed_dict={self.input_data: batch_data, self.labels: batch_labels, self.dropout_rate: 0.5})
if (i % 2 == 0 and i != 0 or i == self.c.epochs-1):
saver.save(sess, save_model_file, global_step=2)
И это моя функция восстановления:
tf.reset_default_graph()
saver = tf.train.import_meta_graph(trained_model_name)
with tf.Session() as sess:
saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()
X_init = tf.placeholder(tf.float32, shape=(c.vocab_size, c.emb_size))
input_data = graph.get_tensor_by_name("input_data")
preds = graph.get_tensor_by_name("preds")
init = tf.global_variables_initializer()
sess.run(init, feed_dict={X_init: lexvec_model})
pred = sess.run(preds, feed_dict={input_data: model_input})
Цель состоит в том, чтобы использовать восстановленную модель, чтобы сделать выводы, но я получаю сообщение об ошибке "saver = tf.train.import_meta_graph (train_model_name)",Некоторая помощь будет отличной :)
Код ошибки:
Traceback (most recent call last):
File "C:/Users/.../main/Predictor.py", line 94, in <module>
prediction = predictor.predict(text_doc=doc)
File "C:/Users/.../main/Predictor.py", line 57, in predict
saver = tf.train.import_meta_graph(trained_model_name)
File "C:\Users\...\Python36\lib\site-packages\tensorflow\python\training\saver.py", line 1927, in import_meta_graph **kwargs)
File "C:\Users\...\Python\Python36\lib\site packages\tensorflow\python\framework\meta_graph.py", line 741, in import_scoped_meta_graph
producer_op_list=producer_op_list)
File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func return func(*args, **kwargs)
File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\framework\importer.py", line 457, in import_graph_def _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\framework\importer.py", line 227, in _RemoveDefaultAttrs
op_def = op_dict[node.op]
KeyError: 'BlockLSTM'