Я использую keras ImageDataGenerator для дополнения моих данных. Но я хочу применить дополнительные пользовательские преобразования к дополненным изображениям. Я знаю, что ImageDataGenerator принимает функцию preprocessing_function, которая будет делать именно это, за исключением того, что мои дополнительные преобразования требуют ввода как изображения, так и истинности земли, в то время как функция preprocessing_function принимает только одно изображение.
Я реализовал это в очень громоздкий способ (код ниже), и мой вопрос, если нет лучшего способа сделать это. В качестве дополнительных преобразований я пороговую маску (истинность земли) и применяю функцию дополнения, которая принимает некоторые параметры в качестве входных данных, а также изображение вместе с маской.
image_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
mask_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
image_val_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
mask_val_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
image_generator = image_datagen.flow(x_train, seed=seed)
mask_generator = mask_datagen.flow(y_train, seed=seed)
image_val_generator = image_val_datagen.flow(x_val, seed=seed + 1)
mask_val_generator = mask_val_datagen.flow(y_val, seed=seed + 1)
imgs = [next(image_generator) for _ in range(1000)]
masks = [np.where(next(mask_generator) > 0.5, 1, 0).astype('float32') for _ in range(1000)] #because keras datagumentation interpolates the data, a threshold must be taken to make the data binary again
imgs_val = [next(image_val_generator) for _ in range(1000)]
masks_val = [np.where(next(mask_val_generator) > 0.5, 1, 0).astype('float32') for _ in range(1000)]
imgs = np.concatenate(imgs)
masks = np.concatenate(masks)
imgs_val = np.concatenate(imgs_val)
masks_val = np.concatenate(masks_val)
for i in range(imgs.shape[0]):
imgs[i] = augment(imgs[i], masks[i], brightness_range = data_gen_args['brightness_range'], noise_var_range = data_gen_args['noise_var_range'], bias_var_range = data_gen_args['bias_var_range'])
for i in range(imgs_val.shape[0]):
imgs_val[i] = augment(imgs_val[i], masks_val[i], brightness_range = data_gen_args['brightness_range'], noise_var_range = data_gen_args['noise_var_range'], bias_var_range = data_gen_args['bias_var_range'])
train_dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(imgs), tf.data.Dataset.from_tensor_slices(masks)))
train_dataset = train_dataset.repeat().shuffle(1000).batch(32)
validation_set = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(imgs_val), tf.data.Dataset.from_tensor_slices(masks_val)))
validation_set = validation_set.repeat().shuffle(1000).batch(32)