Я написал программу для загрузки данных по tf.data с Tensorflow 2.1.0. Я хотел ускорить конвейер данных и изучил документ в https://www.tensorflow.org/guide/data_performance#vectorizing_mapping.
. Я хотел бы применить векторизованное отображение для tf.data, и фрагмент кода приведен ниже:
import tensorflow as tf
data = tf.data.TFRecordDataset(['images.tfrecord'])
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'bboxes': tf.io.VarLenFeature(tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
def parse_example(example):
data = tf.io.parse_example(example, image_feature_description)
img = tf.io.decode_jpeg(data['image_raw'])
img = tf.image.resize(img, (416, 416))
bboxes = data['bboxes']
bboxes = tf.sparse.to_dense(bboxes)
bboxes = tf.reshape(bboxes, [-1, 5])
return img, bboxes
#data = data.map(parse_example).batch(1) # this works
data = data.batch(1).map(parse_example) # apply vectorizing mapping but it raises errors
Ошибки перечислены ниже:
Traceback (most recent call last):
File "test_tfrecord.py", line 28, in <module>
data = data.batch(1).map(parse_example)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 1588, in map
return MapDataset(self, map_func, preserve_cardinality=True)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3888, in __init__
use_legacy_function=use_legacy_function)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3147, in __init__
self._function = wrapper_fn._get_concrete_function_internal()
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2395, in _get_concrete_function_internal
*args, **kwargs)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3140, in wrapper_fn
ret = _wrapper_helper(*args)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 3082, in _wrapper_helper
ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
File "/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/autograph/impl/api.py", line 237, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:
test_tfrecord.py:16 parse_example *
img = tf.io.decode_jpeg(data['image_raw'])
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_image_ops.py:1092 decode_jpeg
dct_method=dct_method, name=name)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/op_def_library.py:742 _apply_op_helper
attrs=attr_protos, op_def=op_def)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:595 _create_op_internal
compute_device)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:3322 _create_op_internal
op_def=op_def)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1786 __init__
control_input_ops)
/home/wilson/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1622 _create_c_op
raise ValueError(str(e))
ValueError: Shape must be rank 0 but is rank 1 for 'DecodeJpeg' (op: 'DecodeJpeg') with input shapes: [?].
Как это исправить? Спасибо