Я хочу записать keras imageDataGenerator в конвейер tf.dataset. Код в основном взят из учебника по Loadimage со страницы tf. Я хочу иметь генератор, который извлекает данные и делает некоторые дополнения изображения, так что я получаю разные наборы тренировочных данных в каждую эпоху.
def process_path(file_path, augment=False):
label = get_label(file_path)
img = tf.io.read_file(file_path)
img = decode_img(img, augment_function=augment)
return img, label
def decode_img(img, augment_function=False):
img = tf.image.decode_jpeg(img, channels=1)
img = tf.image.convert_image_dtype(img, tf.float32)
img= tf.image.flip_left_right(img)
img=tf.image.rot90(img, k=tf.random.uniform(shape=[], minval=1,
maxval=4, dtype=tf.int32))
img= tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT],
method='gaussian',
antialias=True)
return img
def count(list_ds=data):
inter=list_ds.map(process_path,num_parallel_calls=4)
train_d, _, _=split_data(inter,inter, ratio=0.85)
yield train_d
ds_counter = tf.data.Dataset.from_generator(count, output_types=(tf.float32, tf.bool),
output_shapes=((None, 64, 64, 1), (None, 15)))
train_d - это мои пакетные тренировочные данные в tf Формат набора данных:
train_d
<ConcatenateDataset shapes: ((None, 64, 64, 1), (None, 15)), types: (tf.float32, tf.bool)>
Вызов ds_counter возвращает:
<FlatMapDataset shapes: ((None, 64, 64, 1), (None, 15)), types: (tf.float32, tf.bool)>
Это выглядит правильно, у меня 15 меток классов.
his=model.fit(ds_counter,validation_data=ds_counter, epochs=500,callbacks=callbacklist, verbose=0)
бросков:
InvalidArgumentError: TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.bool), but the yielded element was <BatchDataset shapes: ((None, 64, 64, 1), (None, 15)), types: (tf.float32, tf.bool)>.
Вероятно, я мог бы избежать этого, написав генератор для получения массивов, но я бы хотел придерживаться методов tf daset, которые мне очень нравятся.
Большое спасибо