Как лучше использовать tf.data для создания конвейера согласно официальным рекомендациям? - PullRequest
0 голосов
/ 11 января 2020

Я написал конвейер для подготовки данных OCR для моих сетей. Я заметил руководство по производительности в tenorflow.org . Существует краткое изложение того, как сделать конвейер лучше, но я путаюсь с тем, как улучшить мой трубопровод.

Во-первых, я обнаружил, что страница сводки использует наследование для определения собственного набора данных, но большинство учебников по Официальный сайт таким способом не создается. Это удобнее? Второй - tf.data.Dataset.interleave, использование этой функции сделает мой конвейер лучше? Я хочу провести эксперимент для проверки, но я не знаю, как использовать этот API в моем конвейере ……

class OCRDataLoader():
    def __init__(self, 
                 annotation_paths, 
                 image_height, 
                 image_width, 
                 table_path, 
                 blank_index=0, 
                 batch_size=1, 
                 shuffle=False, 
                 repeat=1):

        imgpaths, labels = self.read_imagepaths_and_labels(annotation_paths)
        self.batch_size = batch_size
        self.image_width = image_width
        self.image_height = image_height
        self.size = len(imgpaths)

        file_init = tf.lookup.TextFileInitializer(table_path, 
                                                  tf.string, 
                                                  tf.lookup.TextFileIndex.WHOLE_LINE,
                                                  tf.int64,
                                                  tf.lookup.TextFileIndex.LINE_NUMBER)
        # Default value for blank label
        self.table = tf.lookup.StaticHashTable(initializer=file_init, default_value=blank_index)

        dataset = tf.data.Dataset.from_tensor_slices((imgpaths, labels))
        if shuffle:
            dataset = dataset.shuffle(buffer_size=self.size)
        dataset = dataset.map(self._decode_and_resize)
        # Ignore the errors e.g. decode error or invalid data.
        dataset = dataset.apply(tf.data.experimental.ignore_errors()) 
        # Pay attention to the location of the batch function.
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(self._convert_label)
        dataset = dataset.repeat(repeat)

        self.dataset = dataset

    def read_imagepaths_and_labels(self, annotation_path):
        """Read txt file to get image paths and labels."""

        imgpaths = []
        labels = []
        for annpath in annotation_path.split(','):
            # If you use your own dataset, maybe you should change the parse code below.
            annotation_folder = os.path.dirname(annpath)
            with open(annpath) as f:
                content = np.array([line.strip().split() for line in f.readlines()])
            imgpaths_local = content[:, 0]
            # Parse MjSynth dataset. format: XX_label_XX.jpg XX
            # URL: https://www.robots.ox.ac.uk/~vgg/data/text/            
            labels_local = [line.split("_")[1] for line in imgpaths_local]

            # Parse example dataset. format: XX.jpg label
            # labels_local = content[:, 1]

            imgpaths_local = [os.path.join(annotation_folder, line) for line in imgpaths_local]
            imgpaths.extend(imgpaths_local)
            labels.extend(labels_local)

        return imgpaths, labels

    def _decode_and_resize(self, filename, label):
        image = tf.io.read_file(filename)
        image = tf.io.decode_jpeg(image, channels=1)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, [self.image_height, self.image_width])
        return image, label

    def _convert_label(self, image, label):
        chars = tf.strings.unicode_split(label, input_encoding="UTF-8")
        mapped_label = tf.ragged.map_flat_values(self.table.lookup, chars)
        sparse_label = mapped_label.to_sparse()
        sparse_label = tf.cast(sparse_label, tf.int32)
        return image, sparse_label

    def __call__(self):
        """Return tf.data.Dataset."""
        return self.dataset

    def __len__(self):
        return self.size

Этот набор данных будет отображать (пути к изображениям, метку строки) в (данные изображения) , сопоставленный с int), запись по tensflow2.

...