В последнее время я пытаюсь использовать TFRecords для обучения модели tf.keras. Поскольку для TensorFlow 2 наиболее эффективным способом является использование API tf.data, я стараюсь использовать его для обучения своей модели MTCNN Keras. Но вот что меня смущает:
Согласно оригинальной статье, разные сэмплы (pos, neg, part-face, landmark) участвуют в разных частях тренинга. И каждый тип выборок имеет определенное соотношение c в каждой мини-партии, т.е. для одной мини-партии крыса ios для образцов pos, neg, part-face и landmark должна составлять 1: 3: 1: 2.
Так, например, после того, как я сделал четыре TFRecords сэмплов, мне нужно будет взять 128 сэмплов из pos tfrecords, 384 из neg, 128 из part-face и 256 из landmark tfrecords и смешать их в одну мини-партию. А потом мне нужно перемешать партию перед тренировкой.
Я действительно не знаю, как это сделать при использовании TFRecords и API tf.data. Теперь я могу выполнить все эти шаги, только читая изображения и ярлыки, но это слишком медленно для обучения. Поэтому мне интересно, есть ли эффективный способ сделать это.
Любой совет приветствуется.
Обновлено 2020.04.04 15: 38
Благодаря @AAudibert, я думаю, что его / ее ответ работает очень хорошо, и я также нашел один способ реализовать это. Вот код для справки:
raw_pos_dataset = tf.data.TFRecordDataset(POS_TFRECORDS_PATH_LIST)
raw_neg_dataset = tf.data.TFRecordDataset(NEG_TFRECORDS_PATH_LIST)
raw_part_dataset = tf.data.TFRecordDataset(PART_TFRECORDS_PATH_LIST)
raw_landmark_dataset = tf.data.TFRecordDataset(LANDMARK_TFRECORDS_PATH_LIST)
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'info': tf.io.FixedLenFeature([17], tf.float32),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
def _read_tfrecord(serialized_example):
example = tf.io.parse_single_example(serialized_example, image_feature_description)
img = tf.image.decode_jpeg(example['image_raw'], channels = 3) # RGB rather than BGR!!!
img = (tf.cast(img, tf.float32) - 127.5) / 128.
img_shape = [example['height'], example['width'], example['depth']]
img = tf.reshape(img, img_shape)
info = example['info']
return img, info
parsed_pos_dataset = raw_pos_dataset.map(_read_tfrecord)
parsed_neg_dataset = raw_neg_dataset.map(_read_tfrecord)
parsed_part_dataset = raw_part_dataset.map(_read_tfrecord)
parsed_landmark_dataset = raw_landmark_dataset.map(_read_tfrecord)
parsed_image_dataset = tf.data.Dataset.zip((parsed_pos_dataset.repeat().shuffle(16384).batch(int(BATCH_SIZE * DATA_COMPOSE_RATIO[0])),
parsed_neg_dataset.repeat().shuffle(16384).batch(int(BATCH_SIZE * DATA_COMPOSE_RATIO[1])),
parsed_part_dataset.repeat().shuffle(16384).batch(int(BATCH_SIZE * DATA_COMPOSE_RATIO[2])),
parsed_landmark_dataset.repeat().shuffle(16384).batch(int(BATCH_SIZE * DATA_COMPOSE_RATIO[3]))))
def concatenate(pos_info, neg_info, part_info, landmark_info):
img_tensor = tf.zeros((0, IMG_SIZE, IMG_SIZE, 3), dtype = tf.float32)
label_tensor = tf.zeros((0, 17), dtype = tf.float32)
pos_img = pos_info[0]
neg_img = neg_info[0]
part_img = part_info[0]
landmark_img = landmark_info[0]
pos_info = pos_info[1]
neg_info = neg_info[1]
part_info = part_info[1]
landmark_info = landmark_info[1]
img_tensor = tf.concat([img_tensor, pos_img, neg_img, part_img, landmark_img], axis = 0)
info_tensor = tf.concat([label_tensor, pos_info, neg_info, part_info, landmark_info], axis = 0)
return img_tensor, info_tensor
ds = parsed_image_dataset.map(concatenate)