Reading a tfrecord during batch generation with a single tf session
0 votes
/ September 19, 2018

I am trying to adapt a data generator for the Keras model.fit_generator() method. The idea is to read the image at a given index from a tfrecord while a batch is being built.

So, here is my generator class:

class DataGeneratorCustom:

    def __init__(self, ...):
        ...

    def generate(self,
                 batch_size=32,
                 shuffle=True,
                 transformations=[],
                 label_encoder=None,
                 returns={'processed_images', 'encoded_labels'},
                 keep_images_without_gt=False,
                 degenerate_box_handling='remove'):

        '''
        Yields:
            The next batch as a tuple of items as defined by the `returns` argument.
        '''

        if self.dataset_size == 0:
            raise DatasetError("Cannot generate batches because you did not load a dataset.")

        #############################################################################################
        # Warn if any of the set returns aren't possible.
        #############################################################################################

        if self.labels is None:
            if any([ret in returns for ret in ['original_labels', 'processed_labels', 'encoded_labels', 'matched_anchors', 'evaluation-neutral']]):
                warnings.warn("Since no labels were given, none of 'original_labels', 'processed_labels', 'evaluation-neutral', 'encoded_labels', and 'matched_anchors' " +
                              "are possible returns, but you set `returns = {}`. The impossible returns will be `None`.".format(returns))
        elif label_encoder is None:
            if any([ret in returns for ret in ['encoded_labels', 'matched_anchors']]):
                warnings.warn("Since no label encoder was given, 'encoded_labels' and 'matched_anchors' aren't possible returns, " +
                              "but you set `returns = {}`. The impossible returns will be `None`.".format(returns))
        elif not isinstance(label_encoder, SSDInputEncoder):
            if 'matched_anchors' in returns:
                warnings.warn("`label_encoder` is not an `SSDInputEncoder` object, therefore 'matched_anchors' is not a possible return, " +
                              "but you set `returns = {}`. The impossible returns will be `None`.".format(returns))

        #############################################################################################
        # Do a few preparatory things like maybe shuffling the dataset initially.
        #############################################################################################

        if shuffle:
            objects_to_shuffle = [self.dataset_indices]
            if not (self.filenames is None):
                objects_to_shuffle.append(self.filenames)
            if not (self.labels is None):
                objects_to_shuffle.append(self.labels)
            if not (self.image_ids is None):
                objects_to_shuffle.append(self.image_ids)
            if not (self.eval_neutral is None):
                objects_to_shuffle.append(self.eval_neutral)
            shuffled_objects = sklearn.utils.shuffle(*objects_to_shuffle)
            for i in range(len(objects_to_shuffle)):
                objects_to_shuffle[i][:] = shuffled_objects[i]

        if degenerate_box_handling == 'remove':
            box_filter = BoxFilter(check_overlap=False,
                                   check_min_area=False,
                                   check_degenerate=True,
                                   labels_format=self.labels_format)

        # Override the labels formats of all the transformations to make sure they are set correctly.
        if not (self.labels is None):
            for transform in transformations:
                transform.labels_format = self.labels_format

        #############################################################################################
        # Generate mini batches.
        #############################################################################################

        current = 0

        while True:

            batch_X, batch_y = [], []

            if current >= self.dataset_size:
                current = 0

            #########################################################################################
            # Maybe shuffle the dataset if a full pass over the dataset has finished.
            #########################################################################################

                if shuffle:
                    objects_to_shuffle = [self.dataset_indices]
                    if not (self.filenames is None):
                        objects_to_shuffle.append(self.filenames)
                    if not (self.labels is None):
                        objects_to_shuffle.append(self.labels)
                    if not (self.image_ids is None):
                        objects_to_shuffle.append(self.image_ids)
                    if not (self.eval_neutral is None):
                        objects_to_shuffle.append(self.eval_neutral)
                    shuffled_objects = sklearn.utils.shuffle(*objects_to_shuffle)
                    for i in range(len(objects_to_shuffle)):
                        objects_to_shuffle[i][:] = shuffled_objects[i]

            #########################################################################################
            # Get the images, (maybe) image IDs, (maybe) labels, etc. for this batch.
            #########################################################################################

            # We prioritize our options in the following order:
            # 1) If we have the images already loaded in memory, get them from there.
            # 2) Else, if we have a TFRecord dataset, get the images from there.
            # 3) Else, if we have neither of the above, we'll have to load the individual image
            #    files from disk.
            batch_indices = self.dataset_indices[current:current+batch_size]
            if not (self.images is None):
                for i in batch_indices:
                    batch_X.append(self.images[i])
                if not (self.filenames is None):
                    batch_filenames = self.filenames[current:current+batch_size]
                else:
                    batch_filenames = None
            # elif not (self.hdf5_dataset is None):
            #     for i in batch_indices:
            #         batch_X.append(self.hdf5_dataset['images'][i].reshape(self.hdf5_dataset['image_shapes'][i]))
            elif not (self.tfrecord_dataset is None):
                for i in batch_indices:
                    image, image_shape = self.tfrecord_extract_image(i)
                    batch_X.append(image.reshape(image_shape))
                    # batch_X.append(self.hdf5_dataset['images'][i].reshape(self.hdf5_dataset['image_shapes'][i]))
                if not (self.filenames is None):
                    batch_filenames = self.filenames[current:current+batch_size]
                else:
                    batch_filenames = None
            else:
                batch_filenames = self.filenames[current:current+batch_size]
                for filename in batch_filenames:
                    with Image.open(filename) as image:
                        batch_X.append(np.array(image, dtype=np.uint8))

            # Get the labels for this batch (if there are any).
            if not (self.labels is None):
                batch_y = deepcopy(self.labels[current:current+batch_size])
            else:
                batch_y = None

            if not (self.eval_neutral is None):
                batch_eval_neutral = self.eval_neutral[current:current+batch_size]
            else:
                batch_eval_neutral = None

            # Get the image IDs for this batch (if there are any).
            if not (self.image_ids is None):
                batch_image_ids = self.image_ids[current:current+batch_size]
            else:
                batch_image_ids = None

            if 'original_images' in returns:
                batch_original_images = deepcopy(batch_X) # The original, unaltered images
            if 'original_labels' in returns:
                batch_original_labels = deepcopy(batch_y) # The original, unaltered labels

            current += batch_size

            #########################################################################################
            # Maybe perform image transformations.
            #########################################################################################

            batch_items_to_remove = [] # In case we need to remove any images from the batch, store their indices in this list.
            batch_inverse_transforms = []

            for i in range(len(batch_X)):

                if not (self.labels is None):
                    # Convert the labels for this image to an array (in case they aren't already).
                    batch_y[i] = np.array(batch_y[i])
                    # If this image has no ground truth boxes, maybe we don't want to keep it in the batch.
                    if (batch_y[i].size == 0) and not keep_images_without_gt:
                        batch_items_to_remove.append(i)
                        batch_inverse_transforms.append([])
                        continue

                # Apply any image transformations we may have received.
                if transformations:

                    inverse_transforms = []

                    for transform in transformations:

                        if not (self.labels is None):

                            if ('inverse_transform' in returns) and ('return_inverter' in inspect.signature(transform).parameters):
                                batch_X[i], batch_y[i], inverse_transform = transform(batch_X[i], batch_y[i], return_inverter=True)
                                inverse_transforms.append(inverse_transform)
                            else:
                                batch_X[i], batch_y[i] = transform(batch_X[i], batch_y[i])

                            if batch_X[i] is None: # In case the transform failed to produce an output image, which is possible for some random transforms.
                                batch_items_to_remove.append(i)
                                batch_inverse_transforms.append([])
                                continue

                        else:

                            if ('inverse_transform' in returns) and ('return_inverter' in inspect.signature(transform).parameters):
                                batch_X[i], inverse_transform = transform(batch_X[i], return_inverter=True)
                                inverse_transforms.append(inverse_transform)
                            else:
                                batch_X[i] = transform(batch_X[i])

                    batch_inverse_transforms.append(inverse_transforms[::-1])

                #########################################################################################
                # Check for degenerate boxes in this batch item.
                #########################################################################################

                if not (self.labels is None):

                    xmin = self.labels_format['xmin']
                    ymin = self.labels_format['ymin']
                    xmax = self.labels_format['xmax']
                    ymax = self.labels_format['ymax']

                    if np.any(batch_y[i][:,xmax] - batch_y[i][:,xmin] <= 0) or np.any(batch_y[i][:,ymax] - batch_y[i][:,ymin] <= 0):
                        if degenerate_box_handling == 'warn':
                            warnings.warn("Detected degenerate ground truth bounding boxes for batch item {} with bounding boxes {}, ".format(i, batch_y[i]) +
                                          "i.e. bounding boxes where xmax <= xmin and/or ymax <= ymin. " +
                                          "This could mean that your dataset contains degenerate ground truth boxes, or that any image transformations you may apply might " +
                                          "result in degenerate ground truth boxes, or that you are parsing the ground truth in the wrong coordinate format. " +
                                          "Degenerate ground truth bounding boxes may lead to NaN errors during the training.")
                        elif degenerate_box_handling == 'remove':
                            batch_y[i] = box_filter(batch_y[i])
                            if (batch_y[i].size == 0) and not keep_images_without_gt:
                                batch_items_to_remove.append(i)

            #########################################################################################
            # Remove any items we might not want to keep from the batch.
            #########################################################################################

            if batch_items_to_remove:
                for j in sorted(batch_items_to_remove, reverse=True):
                    # This isn't efficient, but it hopefully shouldn't need to be done often anyway.
                    batch_X.pop(j)
                    if not (batch_filenames is None): batch_filenames.pop(j)
                    if batch_inverse_transforms: batch_inverse_transforms.pop(j)
                    if not (self.labels is None): batch_y.pop(j)
                    if not (self.image_ids is None): batch_image_ids.pop(j)
                    if not (self.eval_neutral is None): batch_eval_neutral.pop(j)
                    if 'original_images' in returns: batch_original_images.pop(j)
                    if 'original_labels' in returns and not (self.labels is None): batch_original_labels.pop(j)

            #########################################################################################

            # CAUTION: Converting `batch_X` into an array will result in an empty batch if the images have varying sizes
            #          or varying numbers of channels. At this point, all images must have the same size and the same
            #          number of channels.
            batch_X = np.array(batch_X)
            if (batch_X.size == 0):
                raise DegenerateBatchError("You produced an empty batch. This might be because the images in the batch vary " +
                                           "in their size and/or number of channels. Note that after all transformations " +
                                           "(if any were given) have been applied to all images in the batch, all images " +
                                           "must be homogenous in size along all axes.")

            #########################################################################################
            # If we have a label encoder, encode our labels.
            #########################################################################################

            if not (label_encoder is None or self.labels is None):

                if ('matched_anchors' in returns) and isinstance(label_encoder, SSDInputEncoder):
                    batch_y_encoded, batch_matched_anchors = label_encoder(batch_y, diagnostics=True)
                else:
                    batch_y_encoded = label_encoder(batch_y, diagnostics=False)
                    batch_matched_anchors = None

            else:
                batch_y_encoded = None
                batch_matched_anchors = None

            #########################################################################################
            # Compose the output.
            #########################################################################################

            ret = []
            if 'processed_images' in returns: ret.append(batch_X)
            if 'encoded_labels' in returns: ret.append(batch_y_encoded)
            if 'matched_anchors' in returns: ret.append(batch_matched_anchors)
            if 'processed_labels' in returns: ret.append(batch_y)
            if 'filenames' in returns: ret.append(batch_filenames)
            if 'image_ids' in returns: ret.append(batch_image_ids)
            if 'evaluation-neutral' in returns: ret.append(batch_eval_neutral)
            if 'inverse_transform' in returns: ret.append(batch_inverse_transforms)
            if 'original_images' in returns: ret.append(batch_original_images)
            if 'original_labels' in returns: ret.append(batch_original_labels)

            yield ret

    def tfrecord_extract_image(self,
                               index):

        # tf.keras.backend.clear_session()
        iterator = self.tfrecord_dataset.make_one_shot_iterator()
        next_record = iterator.get_next()

        # with tf.Graph().as_default():
        # with tf.keras.backend.get_session() as session:

        # Iterate with a tensorflow-session
        # with self.session.as_default() as default_session:

        # Jump to the record of the index
        if index > 0:
            for i in range(index):
                # K.get_session().run(next_record)
                # session.run(next_record)
                self.session.run(next_record)

        # Extract and return the image
        # image, labels, image_shape, labels_shape, image_id, eval_neutral = session.run(next_record)
        # image, labels, image_shape, labels_shape, image_id, eval_neutral = K.get_session().run(next_record)
        image, labels, image_shape, labels_shape, image_id, eval_neutral = self.session.run(next_record)

        # Decode the fields
        image_shape = tf.decode_raw(image_shape, tf.int32)
        image_shape = image_shape.eval()
        image = tf.decode_raw(image, tf.uint8)
        image = image.eval()
        image = image.reshape(image_shape)

        return image, image_shape

This generator is created externally and passed to the model through fit_generator():

history = model.fit_generator(generator=train_generator,
                          steps_per_epoch=steps_per_epoch,
                          epochs=final_epoch,
                          callbacks=callbacks,
                          validation_data=val_generator,
                          validation_steps=ceil(val_dataset_size/batch_size),
                          initial_epoch=initial_epoch)

The only piece of code that gives me trouble is tfrecord_extract_image(). To read a record I need a tf.Session(), and indeed, by using the with keyword on tf.Session(), I can read the tfrecord:

def tfrecord_extract_image(self,
                       index):

    # tf.keras.backend.clear_session()
    iterator = self.tfrecord_dataset.make_one_shot_iterator()
    next_record = iterator.get_next()

    with tf.Session() as session:

        # Jump to the record of the index
        if index > 0:
            for i in range(index):
                # K.get_session().run(next_record)
                # session.run(next_record)
                session.run(next_record)

        # Extract and return the image
        # image, labels, image_shape, labels_shape, image_id, eval_neutral = session.run(next_record)
        # image, labels, image_shape, labels_shape, image_id, eval_neutral = K.get_session().run(next_record)
        image, labels, image_shape, labels_shape, image_id, eval_neutral = session.run(next_record)

        # Decode the fields
        image_shape = tf.decode_raw(image_shape, tf.int32)
        image_shape = image_shape.eval()
        image = tf.decode_raw(image, tf.uint8)
        image = image.eval()
        image = image.reshape(image_shape)

        return image, image_shape

So I open a session for every lookup, but this causes a lot of problems when I run on Google ML. The cloud machine is forced to create a new GPU device instance on every batch step:

12/100 [==>...........................] - ETA: 9:46 - loss: 32.8772  master-replica-0
Adding visible gpu devices: 0  master-replica-0
Device interconnect StreamExecutor with strength 1 edge matrix:  master-replica-0
  0   master-replica-0
  N   master-replica-0
Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10763 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7)  master-replica-0

13/100 [==>...........................] - ETA: 9:21 - loss: 32.8790  master-replica-0
Adding visible gpu devices: 0  master-replica-0
Device interconnect StreamExecutor with strength 1 edge matrix:  master-replica-0
  0   master-replica-0
  N   master-replica-0
Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10763 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7)  master-replica-0

 14/100 [===>..........................] - ETA: 9:25 - loss: 32.5690  master-replica-0
Adding visible gpu devices: 0  master-replica-0
Device interconnect StreamExecutor with strength 1 edge matrix:  master-replica-0
  0   master-replica-0
  0:   N   master-replica-0
Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10763 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7)  master-replica-0

 15/100 [===>..........................] - ETA: 9:00 - loss: 32.9770  master-replica-0

So I tried to:

  • Create a single session in the generator's __init__, but that does not work ...
  • Use keras.backend.get_session() as the single session, but that does not work either ...

In each case I got this error:

Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=() dtype=string> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(), dtype=string) is not an element of this graph.)

How can I use a single tf.Session for my batch generator?

1 Answer

0 votes
/ September 25, 2018

The problem was the sessions. Opening several sessions during training forces the system to create new GPU devices.

The solution is to move with tf.Session() as session to before while True.
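The answer does not show the resulting code, so here is a minimal sketch of that restructuring, written without TensorFlow so that it runs anywhere. CountingSession is a hypothetical stand-in for tf.Session, and generate() is a heavily simplified stand-in for the generator in the question. It only illustrates that a with block placed before while True is entered once, however many batches are drawn.

```python
import itertools


class CountingSession:
    """Hypothetical stand-in for tf.Session; counts how often it is opened."""
    opened = 0

    def __enter__(self):
        CountingSession.opened += 1
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False

    def run(self, fetch):
        # A real session would evaluate a tensor; here `fetch` is a plain callable.
        return fetch()


def generate(records, batch_size, session_factory=CountingSession):
    # Open the session once, before the endless batch loop, so that every
    # batch reuses it instead of creating a new one.
    with session_factory() as session:
        current = 0
        while True:
            if current >= len(records):
                current = 0
            stop = min(current + batch_size, len(records))
            batch = [session.run(lambda i=i: records[i]) for i in range(current, stop)]
            current += batch_size
            yield batch


# Draw four batches from five records; the session is opened only once.
batches = list(itertools.islice(generate(list(range(5)), batch_size=2), 4))
print(batches)                  # [[0, 1], [2, 3], [4], [0, 1]]
print(CountingSession.opened)   # 1
```

In the question's setting the factory would be tf.Session. The iterator returned by make_one_shot_iterator() would presumably also have to be created in the same graph as that session, which is an assumption here rather than something the answer states.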
