Я написал конвейер для подготовки данных OCR для моих сетей. Я заметил руководство по производительности в tenorflow.org . Существует краткое изложение того, как сделать конвейер лучше, но я путаюсь с тем, как улучшить мой трубопровод.
Во-первых, я обнаружил, что страница сводки использует наследование для определения собственного набора данных, но большинство учебников по Официальный сайт таким способом не создается. Это удобнее? Второй - tf.data.Dataset.interleave
, использование этой функции сделает мой конвейер лучше? Я хочу провести эксперимент для проверки, но я не знаю, как использовать этот API в моем конвейере ……
class OCRDataLoader():
def __init__(self,
annotation_paths,
image_height,
image_width,
table_path,
blank_index=0,
batch_size=1,
shuffle=False,
repeat=1):
imgpaths, labels = self.read_imagepaths_and_labels(annotation_paths)
self.batch_size = batch_size
self.image_width = image_width
self.image_height = image_height
self.size = len(imgpaths)
file_init = tf.lookup.TextFileInitializer(table_path,
tf.string,
tf.lookup.TextFileIndex.WHOLE_LINE,
tf.int64,
tf.lookup.TextFileIndex.LINE_NUMBER)
# Default value for blank label
self.table = tf.lookup.StaticHashTable(initializer=file_init, default_value=blank_index)
dataset = tf.data.Dataset.from_tensor_slices((imgpaths, labels))
if shuffle:
dataset = dataset.shuffle(buffer_size=self.size)
dataset = dataset.map(self._decode_and_resize)
# Ignore the errors e.g. decode error or invalid data.
dataset = dataset.apply(tf.data.experimental.ignore_errors())
# Pay attention to the location of the batch function.
dataset = dataset.batch(batch_size)
dataset = dataset.map(self._convert_label)
dataset = dataset.repeat(repeat)
self.dataset = dataset
def read_imagepaths_and_labels(self, annotation_path):
"""Read txt file to get image paths and labels."""
imgpaths = []
labels = []
for annpath in annotation_path.split(','):
# If you use your own dataset, maybe you should change the parse code below.
annotation_folder = os.path.dirname(annpath)
with open(annpath) as f:
content = np.array([line.strip().split() for line in f.readlines()])
imgpaths_local = content[:, 0]
# Parse MjSynth dataset. format: XX_label_XX.jpg XX
# URL: https://www.robots.ox.ac.uk/~vgg/data/text/
labels_local = [line.split("_")[1] for line in imgpaths_local]
# Parse example dataset. format: XX.jpg label
# labels_local = content[:, 1]
imgpaths_local = [os.path.join(annotation_folder, line) for line in imgpaths_local]
imgpaths.extend(imgpaths_local)
labels.extend(labels_local)
return imgpaths, labels
def _decode_and_resize(self, filename, label):
image = tf.io.read_file(filename)
image = tf.io.decode_jpeg(image, channels=1)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [self.image_height, self.image_width])
return image, label
def _convert_label(self, image, label):
chars = tf.strings.unicode_split(label, input_encoding="UTF-8")
mapped_label = tf.ragged.map_flat_values(self.table.lookup, chars)
sparse_label = mapped_label.to_sparse()
sparse_label = tf.cast(sparse_label, tf.int32)
return image, sparse_label
def __call__(self):
"""Return tf.data.Dataset."""
return self.dataset
def __len__(self):
return self.size
Этот набор данных будет отображать (пути к изображениям, метку строки) в (данные изображения) , сопоставленный с int), запись по tensflow2.