Использование генератора, который выдает tf.dataset для model.fit - PullRequest
0 голосов
/ 15 января 2020

Я хочу записать keras imageDataGenerator в конвейер tf.dataset. Код в основном взят из учебника по Loadimage со страницы tf. Я хочу иметь генератор, который извлекает данные и делает некоторые дополнения изображения, так что я получаю разные наборы тренировочных данных в каждую эпоху.

def process_path(file_path, augment=False):
 label = get_label(file_path)
 img = tf.io.read_file(file_path)
 img = decode_img(img, augment_function=augment)
 return img, label

def decode_img(img, augment_function=False):

 img = tf.image.decode_jpeg(img, channels=1)
 img = tf.image.convert_image_dtype(img, tf.float32)
 img= tf.image.flip_left_right(img)
 img=tf.image.rot90(img, k=tf.random.uniform(shape=[], minval=1,
                                            maxval=4, dtype=tf.int32))
 img= tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT],
                     method='gaussian',
                    antialias=True)

 return img

def count(list_ds=data):
 inter=list_ds.map(process_path,num_parallel_calls=4)
 train_d, _, _=split_data(inter,inter, ratio=0.85)
 yield train_d

ds_counter = tf.data.Dataset.from_generator(count, output_types=(tf.float32,  tf.bool),
                                        output_shapes=((None, 64, 64, 1), (None, 15)))

train_d - это мои пакетные тренировочные данные в tf Формат набора данных:

train_d
<ConcatenateDataset shapes: ((None, 64, 64, 1), (None, 15)), types: (tf.float32, tf.bool)>

Вызов ds_counter возвращает:

 <FlatMapDataset shapes: ((None, 64, 64, 1), (None, 15)), types: (tf.float32, tf.bool)>

Это выглядит правильно, у меня 15 меток классов.

his=model.fit(ds_counter,validation_data=ds_counter, epochs=500,callbacks=callbacklist, verbose=0)

бросков:

InvalidArgumentError: TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.bool), but the yielded element was <BatchDataset shapes: ((None, 64, 64, 1), (None, 15)), types: (tf.float32, tf.bool)>.

Вероятно, я мог бы избежать этого, написав генератор для получения массивов, но я бы хотел придерживаться методов tf daset, которые мне очень нравятся.

Большое спасибо

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...