How do I use a model trained on a GPU with CudnnLSTM on a CPU?
May 2, 2018

TensorFlow version 1.6.0 on Ubuntu 16.04.

The network uses CudnnLSTM: https://www.tensorflow.org/api_docs/python/tf/contrib/cudnn_rnn/CudnnLSTM

Exporting the model and running prediction both work on a GPU. But when I export and run inference on a CPU, I get the error below.

  File "/home/deepak/.local/lib/python2.7/site-packages/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py", line 501, in _create_saveable
    name="%s_saveable" % self.trainable_variables[0].name.split(":")[0])
  File "/home/deepak/.local/lib/python2.7/site-packages/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py", line 262, in __init__
    weights, biases = self._OpaqueParamsToCanonical()
  File "/home/deepak/.local/lib/python2.7/site-packages/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py", line 315, in _OpaqueParamsToCanonical
    direction=self._direction)
  File "/home/deepak/.local/lib/python2.7/site-packages/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py", line 769, in cudnn_rnn_params_to_canonical
    name=name)
  File "/home/deepak/.local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/deepak/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
    op_def=op_def)
  File "/home/deepak/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): No OpKernel was registered to support Op 'CudnnRNNParamsToCanonical' with these attrs.  Registered devices: [CPU], Registered kernels:
  <no registered kernels>

     [[Node: CudnnRNNParamsToCanonical = CudnnRNNParamsToCanonical[T=DT_FLOAT, direction="bidirectional", dropout=0, input_mode="linear_input", num_params=16, rnn_mode="lstm", seed=0, seed2=0, _device="/device:GPU:0"](CudnnRNNParamsToCanonical/num_layers, CudnnRNNParamsToCanonical/num_units, CudnnRNNParamsToCanonical/input_size, cudnn_lstm/opaque_kernel/read)]]

The export code is as follows:

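import os
import shutil

import tensorflow as tf

# FLAGS (checkpoint_dir, export_dir, export_version) and create_graph()
# are defined elsewhere in the project and are not shown here.

with tf.Graph().as_default() as graph:
    inputs, outputs = create_graph()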

    # Create a saver using variables from the above newly created graph
    saver = tf.train.Saver(tf.global_variables())

    with tf.Session() as sess:
        # Restore the model from last checkpoints
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)

        # (re-)create export directory
        export_path = os.path.join(
            tf.compat.as_bytes(FLAGS.export_dir),
            tf.compat.as_bytes(str(FLAGS.export_version)))
        if os.path.exists(export_path):
            shutil.rmtree(export_path)

        # create model builder
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        input_node = graph.get_tensor_by_name('input_node:0')
        input_lengths = graph.get_tensor_by_name('input_lengths:0')
        outputs = graph.get_tensor_by_name('output_node:0')

        # create tensors info
        predict_tensor_inputs_info = tf.saved_model.utils.build_tensor_info(input_node)
        predict_tensor_inputs_length_info = tf.saved_model.utils.build_tensor_info(input_lengths)
        predict_tensor_scores_info = tf.saved_model.utils.build_tensor_info(outputs)

        # build prediction signature
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'input': predict_tensor_inputs_info,
                        'input_len': predict_tensor_inputs_length_info},
                outputs={'output': predict_tensor_scores_info},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            )
        )

        # save the model
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'infer': prediction_signature
            })

        builder.save()
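
The error says that `CudnnRNNParamsToCanonical` has no CPU kernel, so a graph that contains the `CudnnLSTM` layer cannot be built for saving or restoring on a CPU-only machine. One direction that may help, untested against this exact setup, is to build the inference graph from `tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell` instead and restore the same checkpoint into it. The sketch below is an assumption-laden outline, not the project's actual code: the function name `build_cpu_bilstm`, its parameters, and the batch-major input layout are hypothetical, and `num_layers` and `num_units` must match the values used in training. The `cudnn_lstm` scope name is taken from the node name `cudnn_lstm/opaque_kernel` in the traceback.

```python
def build_cpu_bilstm(inputs, sequence_length, num_layers, num_units):
    """Build a bidirectional LSTM stack that runs on CPU.

    Hypothetical helper (not from the original post). Assumes `inputs` is
    batch-major, shaped [batch, time, features], and that num_layers and
    num_units equal the values used when the CudnnLSTM model was trained.
    """
    # Imported here so that defining the function does not itself require
    # TensorFlow 1.x (tf.contrib does not exist in TensorFlow 2).
    import tensorflow as tf

    def make_cell():
        # LSTM cell whose variables are laid out to be compatible with
        # checkpoints written by the CudnnLSTM layer.
        return tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units)

    cells_fw = [make_cell() for _ in range(num_layers)]
    cells_bw = [make_cell() for _ in range(num_layers)]

    # "cudnn_lstm" matches the scope seen in the traceback
    # (cudnn_lstm/opaque_kernel); the inner RNN scope is left at its default.
    with tf.variable_scope("cudnn_lstm"):
        outputs, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw,
            cells_bw,
            inputs,
            dtype=tf.float32,
            sequence_length=sequence_length)
    return outputs
```

If `create_graph()` used a helper like this on the CPU machine in place of the `CudnnLSTM` layer, the existing `tf.train.Saver` restore and `SavedModelBuilder` export steps would be expected to stay the same. Note that `CudnnLSTM` takes time-major input, so a transpose to batch-major may be needed before calling the helper.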