TensorFlow 2-tf.keras: Как обучить многозадачную сеть tf.keras, такую ​​как MTCNN, используя API tf.data и TFRecords - PullRequest
0 голосов
/ 31 марта 2020

В последнее время я пытаюсь использовать TFRecords для обучения модели tf.keras. Поскольку для TensorFlow 2 наиболее эффективным способом является использование API tf.data, я стараюсь использовать его для обучения своей модели MTCNN Keras. Но вот что меня смущает:

Согласно оригинальной статье, разные сэмплы (pos, neg, part-face, landmark) участвуют в разных частях тренинга. И каждый тип выборок имеет определенное соотношение c в каждой мини-партии, т.е. для одной мини-партии крыса ios для образцов pos, neg, part-face и landmark должна составлять 1: 3: 1: 2.

Так, например, после того, как я сделал четыре TFRecords сэмплов, мне нужно будет взять 128 сэмплов из pos tfrecords, 384 из neg, 128 из part-face и 256 из landmark tfrecords и смешать их в одну мини-партию. А потом мне нужно перемешать партию перед тренировкой.

Я действительно не знаю, как это сделать при использовании TFRecords и API tf.data. Теперь я могу выполнить все эти шаги, только читая изображения и ярлыки, но это слишком медленно для обучения. Поэтому мне интересно, есть ли эффективный способ сделать это.

Любой совет приветствуется.

Обновлено 2020.04.04 15: 38
Благодаря @AAudibert, я думаю, что его / ее ответ работает очень хорошо, и я также нашел один способ реализовать это. Вот код для справки:

raw_pos_dataset = tf.data.TFRecordDataset(POS_TFRECORDS_PATH_LIST)
raw_neg_dataset = tf.data.TFRecordDataset(NEG_TFRECORDS_PATH_LIST)
raw_part_dataset = tf.data.TFRecordDataset(PART_TFRECORDS_PATH_LIST)
raw_landmark_dataset = tf.data.TFRecordDataset(LANDMARK_TFRECORDS_PATH_LIST)

image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'info': tf.io.FixedLenFeature([17], tf.float32),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
    }

def _read_tfrecord(serialized_example):

    example = tf.io.parse_single_example(serialized_example, image_feature_description)

    img = tf.image.decode_jpeg(example['image_raw'], channels = 3) # RGB rather than BGR!!! 
    img = (tf.cast(img, tf.float32) - 127.5) / 128.
    img_shape = [example['height'], example['width'], example['depth']]
    img = tf.reshape(img, img_shape)

    info = example['info']

    return img, info

parsed_pos_dataset = raw_pos_dataset.map(_read_tfrecord)
parsed_neg_dataset = raw_neg_dataset.map(_read_tfrecord)
parsed_part_dataset = raw_part_dataset.map(_read_tfrecord)
parsed_landmark_dataset = raw_landmark_dataset.map(_read_tfrecord)

parsed_image_dataset = tf.data.Dataset.zip((parsed_pos_dataset.repeat().shuffle(16384).batch(int(BATCH_SIZE * DATA_COMPOSE_RATIO[0])), 
                                            parsed_neg_dataset.repeat().shuffle(16384).batch(int(BATCH_SIZE * DATA_COMPOSE_RATIO[1])), 
                                            parsed_part_dataset.repeat().shuffle(16384).batch(int(BATCH_SIZE * DATA_COMPOSE_RATIO[2])), 
                                            parsed_landmark_dataset.repeat().shuffle(16384).batch(int(BATCH_SIZE * DATA_COMPOSE_RATIO[3]))))

def concatenate(pos_info, neg_info, part_info, landmark_info):

    img_tensor = tf.zeros((0, IMG_SIZE, IMG_SIZE, 3), dtype = tf.float32)
    label_tensor = tf.zeros((0, 17), dtype = tf.float32)
    pos_img = pos_info[0]
    neg_img = neg_info[0]
    part_img = part_info[0]
    landmark_img = landmark_info[0]
    pos_info = pos_info[1]
    neg_info = neg_info[1]
    part_info = part_info[1]
    landmark_info = landmark_info[1]
    img_tensor = tf.concat([img_tensor, pos_img, neg_img, part_img, landmark_img], axis = 0)
    info_tensor = tf.concat([label_tensor, pos_info, neg_info, part_info, landmark_info], axis = 0)

    return img_tensor, info_tensor

ds = parsed_image_dataset.map(concatenate)

1 Ответ

0 голосов
/ 03 апреля 2020

Вы можете сделать выборку в соответствии с указанной крысой ios, используя sample_from_datasets

См. этот colab для примера. Я также скопировал код ниже.

import tensorflow as tf

batch_size = 8

pos = tf.data.Dataset.range(0, 100)
neg = tf.data.Dataset.range(100, 200)
part_face = tf.data.Dataset.range(200, 300)
landmark = tf.data.Dataset.range(300, 400)

dataset = tf.data.experimental.sample_from_datasets(
    [pos, neg, part_face, landmark], [1/7, 3/7, 1/7, 2/7])
dataset = dataset.batch(batch_size)

# Optionally shuffle data further. Samples will already be interspersed between datasets.
# dataset = dataset.map(lambda batch: tf.random.shuffle(batch))

for elem in dataset:
  print(elem.numpy())

# Prints
[200 300 100 201 301 302 101 303]
[  0 304 202 102 203 103 305 104]
[306 307 105 204 308 205 206   1]
[207 309 106 107 310 108 311 312]
[208 209 210   2 109 211 110 212]
...
...