Загрузка сохраненной модели Tensorflow из ее файла .meta - PullRequest
0 голосов
/ 15 января 2020

Я пытаюсь загрузить мета граф Тензорного потока из сохраненной контрольной точки, используя Tensorflow версии 1.15, чтобы преобразовать его в SavedModel для обслуживания тензорного потока. Это модель распознавания речи с локальным вниманием и однонаправленным LSTM, реализованная с использованием инструментария Returnn с Tensorflow Backend. Я использую следующий код.

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
import sys

if len(sys.argv)!=2:
        print("Usage:" + sys.argv[0] + "save_dir")
        exit(1)
export_dir=sys.argv[1]
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
sigs={}
with tf.Session(graph=tf.Graph()) as sess:
        new_saver=tf.train.import_meta_graph("./serv_test/model.238.meta")
        new_saver.restore(sess, tf.train.latest_checkpoint("./serv_test"))
        graph=tf.get_default_graph()
        input_audio=graph.get_tensor_by_name('inference/default/wav:0')
        output_hyps=graph.get_tensor_by_name('inference/default/Reshape_7:0')
        sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = tf.saved_model.signature_def_utils.predict_signature_def({"in":input_audio},{"out":output_hyps})
        builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING], signature_def_map=sigs,)
builder.save()

Но я получаю следующую ошибку в строке import_meta_graph:

Traceback (most recent call last):
  File "xport.py", line 16, in <module>
    new_saver=tf.train.import_meta_graph("./serv_test/model.238.meta")
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1453, in import_meta_graph
    **kwargs)[0]
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1477, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/ubuntu/tf1.15/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered
 'NativeLstm2' in binary running on ip-10-1-21-241. Make sure the Op and Kernel
 are registered in the binary running in this process. Note that if you are loading a
 saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler`
 should be done before importing the graph, as contrib ops are lazily registered when
 the module is first accessed.

Есть ли способ обойти эту ошибку? Это из-за пользовательских слоев, используемых в Returnn? Есть ли способ сделать работоспособный тензор потока Returnn Model? Спасибо.

1 Ответ

0 голосов
/ 09 февраля 2020

Вы должны удалить graph=tf.Graph(), иначе ваш import_meta_graph импортирует его в неправильный график. Просто посмотрите некоторые официальные примеры TF, как использовать import_meta_graph.

...