Примените функцию предварительной обработки к Keras ImageDataGenerator - PullRequest
0 голосов
/ 15 апреля 2020

Я использую keras ImageDataGenerator для дополнения моих данных. Но я хочу применить дополнительные пользовательские преобразования к дополненным изображениям. Я знаю, что ImageDataGenerator принимает функцию preprocessing_function, которая будет делать именно это, за исключением того, что мои дополнительные преобразования требуют ввода как изображения, так и истинности земли, в то время как функция preprocessing_function принимает только одно изображение.

Я реализовал это в очень громоздкий способ (код ниже), и мой вопрос, если нет лучшего способа сделать это. В качестве дополнительных преобразований я пороговую маску (истинность земли) и применяю функцию дополнения, которая принимает некоторые параметры в качестве входных данных, а также изображение вместе с маской.

    image_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
    mask_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
    image_val_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
    mask_val_datagen = kp.image.ImageDataGenerator(**data_gen_args_)

    image_generator = image_datagen.flow(x_train, seed=seed)
    mask_generator = mask_datagen.flow(y_train, seed=seed)
    image_val_generator = image_val_datagen.flow(x_val, seed=seed + 1)
    mask_val_generator = mask_val_datagen.flow(y_val, seed=seed + 1)

    imgs = [next(image_generator) for _ in range(1000)]
    masks = [np.where(next(mask_generator) > 0.5, 1, 0).astype('float32') for _ in range(1000)]   #because keras datagumentation interpolates the data, a threshold must be taken to make the data binary again
    imgs_val = [next(image_val_generator) for _ in range(1000)]
    masks_val = [np.where(next(mask_val_generator) > 0.5, 1, 0).astype('float32') for _ in range(1000)]

    imgs = np.concatenate(imgs)
    masks = np.concatenate(masks)
    imgs_val = np.concatenate(imgs_val)
    masks_val = np.concatenate(masks_val)

    for i in range(imgs.shape[0]):
        imgs[i] = augment(imgs[i], masks[i], brightness_range = data_gen_args['brightness_range'], noise_var_range = data_gen_args['noise_var_range'], bias_var_range = data_gen_args['bias_var_range'])
    for i in range(imgs_val.shape[0]):
        imgs_val[i] = augment(imgs_val[i], masks_val[i], brightness_range = data_gen_args['brightness_range'], noise_var_range = data_gen_args['noise_var_range'], bias_var_range = data_gen_args['bias_var_range'])

    train_dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(imgs), tf.data.Dataset.from_tensor_slices(masks)))
    train_dataset = train_dataset.repeat().shuffle(1000).batch(32)
    validation_set = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(imgs_val), tf.data.Dataset.from_tensor_slices(masks_val)))
    validation_set = validation_set.repeat().shuffle(1000).batch(32)
...