Не удалось применить отображение векторизации для tf.data в Tensorflow 2.1.0 - PullRequest
0 голосов
/ 16 апреля 2020

Я написал программу для загрузки данных по tf.data с Tensorflow 2.1.0. Я хотел ускорить конвейер данных и изучил документ в https://www.tensorflow.org/guide/data_performance#vectorizing_mapping.

. Я хотел бы применить векторизованное отображение для tf.data, и фрагмент кода приведен ниже:

import tensorflow as tf

data = tf.data.TFRecordDataset(['images.tfrecord'])

image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'bboxes': tf.io.VarLenFeature(tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def parse_example(example):
    data = tf.io.parse_example(example, image_feature_description)

    img = tf.io.decode_jpeg(data['image_raw'])

    img = tf.image.resize(img, (416, 416))

    bboxes = data['bboxes']
    bboxes = tf.sparse.to_dense(bboxes)
    bboxes = tf.reshape(bboxes, [-1, 5])

    return img, bboxes


#data = data.map(parse_example).batch(1)  # this works
data = data.batch(1).map(parse_example)  # apply vectorizing mapping but it raises errors

Ошибки перечислены ниже:

Traceback (most recent call last):
  File "test_tfrecord.py", line 28, in <module>
    data = data.batch(1).map(parse_example)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 1588, in map
    return MapDataset(self, map_func, preserve_cardinality=True)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3888, in __init__
    use_legacy_function=use_legacy_function)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3147, in __init__
    self._function = wrapper_fn._get_concrete_function_internal()
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2395, in _get_concrete_function_internal
    *args, **kwargs)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3140, in wrapper_fn
    ret = _wrapper_helper(*args)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3082, in _wrapper_helper
    ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
  File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/autograph/impl/api.py", line 237, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:

test_tfrecord.py:16 parse_example  *
    img = tf.io.decode_jpeg(data['image_raw'])
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_image_ops.py:1092 decode_jpeg
    dct_method=dct_method, name=name)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/op_def_library.py:742 _apply_op_helper
    attrs=attr_protos, op_def=op_def)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:595 _create_op_internal
    compute_device)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:3322 _create_op_internal
    op_def=op_def)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1786 __init__
    control_input_ops)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1622 _create_c_op
    raise ValueError(str(e))

ValueError: Shape must be rank 0 but is rank 1 for 'DecodeJpeg' (op: 'DecodeJpeg') with input shapes: [?].

Как это исправить? Спасибо

1 Ответ

0 голосов
/ 16 апреля 2020

Вы должны сначала проанализировать пример, а затем пакетную:

data = data.map(parse_example).batch(1)

В противном случае размеры не соответствуют вашему парсеру.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...