Как я могу объединить ImageDataGenerator с наборами данных TensorFlow в TF2? - PullRequest
3 голосов
/ 08 января 2020

У меня есть набор данных TF для классификации кошек и собак:

import tensorflow_datasets as tfds
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=list(splits),
    with_info=True, as_supervised=True)

В примере они используют некоторое увеличение изображения с помощью функции карты. Мне было интересно, можно ли это сделать и с классом ImageDataGenerator, как описано здесь :

from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_image_generator = ImageDataGenerator(rescale=1./255) # Generator for our training data
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')

Проблема, с которой я сталкиваюсь, заключается в том, что я могу видеть только 3 способа использовать ImageDataGenerator: pandas массив данных, numpy массив и каталог изображений. Есть ли способ также использовать набор данных Tensorflow и объединить эти методы?

1 Ответ

1 голос
/ 09 января 2020

Да, но это немного сложно.
Keras ImageDataGenerator работает на numpy.array с, а не на tf.Tensor, поэтому мы должны использовать Tensorflow's numpy_function . Это позволит нам выполнять операции с tf.data.Dataset содержимым, как это было numpy массивов.

Сначала давайте объявим функцию, которую мы будем .map над нашим набором данных (предполагая, что ваш набор данных состоит из изображения, пары меток):

# We will take 1 original image and create 5 augmented images:
HOW_MANY_TO_AUGMENT = 5

def augment(image, label):

  # Create generator and fit it to an image
  img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
  img_gen.fit(image)

  # We want to keep original image and label
  img_results = [(image/255.).astype(np.float32)] 
  label_results = [label]

  # Perform augmentation and keep the labels
  augmented_images = [next(img_gen.flow(image)) for _ in range(HOW_MANY_TO_AUGMENT)]
  labels = [label for _ in range(HOW_MANY_TO_AUGMENT)]

  # Append augmented data and labels to original data
  img_results.extend(augmented_images)
  label_results.extend(labels)

  return img_results, label_results

Теперь, чтобы использовать эту функцию внутри tf.data.Dataset, мы должны объявить numpy_function:

def py_augment(image, label):
  func = tf.numpy_function(augment, [image, label], [tf.float32, tf.int32])
  return func

py_augment, который можно безопасно использовать как :

augmented_dataset_ds = image_label_dataset.map(py_augment)

Часть image в наборе данных теперь имеет форму (HOW_MANY_TO_AUGMENT, image_height, image_width, channels). Чтобы преобразовать его в простой (1, image_height, image_width, channels), вы можете просто использовать unbatch:

unbatched_augmented_dataset_ds = augmented_dataset_ds.unbatch()

Таким образом, весь раздел выглядит так:

HOW_MANY_TO_AUGMENT = 5

def augment(image, label):

  # Create generator and fit it to an image
  img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
  img_gen.fit(image)

  # We want to keep original image and label
  img_results = [(image/255.).astype(np.float32)] 
  label_results = [label]

  # Perform augmentation and keep the labels
  augmented_images = [next(img_gen.flow(image)) for _ in range(HOW_MANY_TO_AUGMENT)]
  labels = [label for _ in range(HOW_MANY_TO_AUGMENT)]

  # Append augmented data and labels to original data
  img_results.extend(augmented_images)
  label_results.extend(labels)

  return img_results, label_results

def py_augment(image, label):
  func = tf.numpy_function(augment, [image, label], [tf.float32, tf.int32])
  return func

unbatched_augmented_dataset_ds = augmented_dataset_ds.map(py_augment).unbatch()

# Iterate over the dataset for preview:
for image, label in unbatched_augmented_dataset_ds:
    ...
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...