Тензорный поток в CoreML с tf-coreml - PullRequest
0 голосов
/ 01 октября 2019

У меня есть сеть с несколькими входами, которая использует tf.bool tf.placeholder для управления выполнением нормализации партии при обучении и валидации / тестировании. Я пытался преобразовать эту обученную модель в CoreML через tf-coreml библиотеку безуспешно, с ошибкой ниже:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value

Я понимаю, что эта ошибка говорит о том, чтоэто определенный узел, в котором отсутствует значение, поэтому конвертер может выполнить модель. Я также понимаю, что эта ошибка связана с операциями потока управления (связана с методом пакетной нормализации, создающим такие операции, как Switch и Merge). Исходный код показывает это:

def testSwitchDeadBranch(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      dead_branch = array_ops.identity(switch_op[0])

      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Retval[0] does not have value" in str(e)):
        self.evaluate(dead_branch)

Обратите внимание, что моя ошибка Retval[26] (я получил [24] и т. Д.), А не Retval[0]. Я предполагаю, что он проверяет Switch «мертвую ветвь», которая должна быть неиспользуемой ветвью для вывода. Код также делает то же самое с Merge «мертвой ветвью».

Есть ли какие-то детали, которые мне не хватает, которые могут быть причиной этой ошибки (конечно, не первая ошибка, с которой я столкнулся во время преобразования)? Как сделан вывод? Способ нормализации партии реализован? Как модель сохраняется?

Что я уже сделал:

  • Я использую Tensorflow 1.14.0
  • Я знаю tf.layers.batch_normalization создает операции Switch и Merge, которые не совместимы с CoreML
  • Я пытался преобразовать в Tensorflow Lite с аналогичными проблемами
  • Я следую Facenet (эта модель использует ту же логику tf.bool для обучения,проверка, тестирование) процесс преобразования не увенчался успехом
  • я пробовал GraphTransforms библиотеку
  • я пробовал скрипты для удаления / изменения потока управления
  • Я создал отдельные графики, чтобы избежать лишних операций, но безуспешно

Примечание: я абстрагировал большую часть кода, чтобы опубликовать этот вопрос.

Thisкак осуществляется пакетная нормализация (в блоке свертки).

training = tf.placeholder(tf.bool, shape = (), name = 'training')

def conv_layer(input, kernelSize, nFilters, poolSize, stride, input_channels = 1, name = 'conv'):
        with tf.name_scope(name):
        shape = [kernelSize, kernelSize, input_channels, nFilters]
        weights = new_weights(shape = shape)        biases = new_biases(length = nFilters)
        conv = tf.nn.conv2d(input, weights, strides = [1, 2, 2, 1], padding = 'SAME', name = 'convL')
        conv += biases
        pool = tf.reduce_max(conv, reduction_indices=[3], keep_dims=True, name = 'pool') 
       pool = tf.nn.max_pool(conv, ksize = [1, poolSize, poolSize, 1], strides = shape, padding = 'SAME')
        bnorm = tf.layers.batch_normalization(pool, training = training, center = True, scale = True, fused = False, reuse= False)
        act = tf.nn.relu(bnorm)
        return act

Ниже приведен код для обучения и сохранения модели.

saver = tf.train.Saver()

    with tf.Session(config = config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(init_train_op)

        for epoch in range(MAX_EPOCHS):

            for step in range(10):

                l, _, se = sess.run(
                    [loss, train_op, mean_squared_error],
                     feed_dict = {training: True})

            print('\nRunning validation operation...')

            sess.run(init_val_op)
            for _ in range(10):
                val_out, val_l, val_se = sess.run(
                    [out, val_loss, val_mean_squared_error],
                    feed_dict = {training: False})

            sess.run(init_train_op) # switch back to training set

        #Save model
        print('Saving Model...\n')
        saver.save(sess, join(saveDir, './model_saver_validation'.format(modelIndex)), write_meta_graph = True)

Ниже приведен код для загрузки, обновленияввод, выполнить вывод и заморозить модель.

# Dummy data for inference
b = np.zeros((1, 80, 160, 1), np.float32)
ill = np.ones((1,3), np.float32)
is_train = False

def freeze():
    with tf.Graph().as_default():
        with tf.Session() as sess:
            bIn = tf.placeholder(dtype=tf.float32, shape=[
                             1, 80, 160, 1], name='bIn')
            illumIn = tf.placeholder(dtype=tf.float32, shape=[
                                     1, 3], name='illumIn')
            training = tf.placeholder(tf.bool, shape=(), name = 'training')

            # Load the model metagraph and checkpoint
            meta_file = meta_graph #.meta file from saver.save()
            ckpt_file = checkpoint_file #checkpoint file

            # Load graph to redirect inputs from iterator to expected inputs
            saver = tf.train.import_meta_graph(meta_file, input_map={
                'IteratorGetNext:0': bIn,
                'IteratorGetNext:3': illumIn,
                'training:0': training},  clear_devices = True)

            tf.get_default_session().run(tf.global_variables_initializer())
            tf.get_default_session().run(tf.local_variables_initializer())
            saver.restore(tf.get_default_session(), ckpt_file)

            pred = tf.get_default_graph().get_tensor_by_name('Out:0')

            tf.get_default_session().run(pred, feed_dict={'bIn:0': b, 'poseIn:0': po, 'training:0': is_train})

            # Retrieve the protobuf graph definition and fix the batch norm nodes
            input_graph_def = sess.graph.as_graph_def()

            # Freeze the graph def
            output_graph_def = freeze_graph_def(
                sess, input_graph_def, output_node_names)

        # Serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(frozen_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())

freeze()

Ниже приведен код для преобразования в CoreML.

tfcoreml.convert(
    tf_model_path=frozen_graph,
    mlmodel_path='./coreml_model.mlmodel',
    output_feature_names=['Out:0'],
    input_name_shape_dict={
        'bIn:0': [1, 80, 160, 1],
        'illumIn:0': [1, 3], 
        'training:0': []})

Нижеэто ошибка, выданная tf-coreml.

Loading the TF graph...
Graph Loaded.
Collecting all the 'Const' ops from the graph, by running it....

Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tf2opencv.py", line 392, in <module>
    'illumIn:0': [1, 3], 'poseIn:0': [1, 16], 'training:0': []})
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tfcoreml/_tf_coreml_converter.py", line 586, in convert
    custom_conversion_functions=custom_conversion_functions)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tfcoreml/_tf_coreml_converter.py", line 243, in _convert_pb_to_mlmodel
    tensors_evaluated = sess.run(tensors, feed_dict=input_feed_dict)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value
...