Перестановка / добавление данных с помощью TensorFlow Dataset API - PullRequest
0 голосов
/ 04 ноября 2018

В последние несколько дней я пытался ознакомиться с API TensorFlow Dataset API. Моя цель - создать рабочий проход, который передает ImageNet моим моделям.

Я реализовал некоторый код, который работает, и при этом он читает N осколков файлов TFRecord. Однако мне трудно это делать. Есть несколько вопросов ...

  1. Я хочу перетасовать последовательность фрагментов так, чтобы данные как бы перетасовывались через весь набор данных. Я пытался это с помощью методов, представленных в https://github.com/tensorflow/tensorflow/issues/14857. Однако, когда я пытаюсь использовать tf.data.Dataset.list_files (filenames) , я получаю следующую ошибку

Сбой OP_REQUIRES в example_parsing_ops.cc:144: Неверный аргумент: Не удалось проанализировать пример ввода, значение: 'val / validation-00033-of-00128'

  1. Чтобы уменьшить задержку, я попытался использовать num_parallel_calls . Однако, похоже, что это не делает его быстрее (в настоящее время я использую i7 8700K). Кроме того, когда я пытаюсь увеличить параметр для предварительной выборки или случайного воспроизведения, это, кажется, сильно замедляет его. Я делаю что-то неправильно? или неправильный порядок использования API набора данных?

  2. Я надеюсь построить это так, чтобы его можно было использовать для обучения моделей, и я хочу реализовать расширение данных в функции decode_for_train , которая будет копией decode_for_eval функция. Однако я не знаю как. Один из методов, который я рассмотрел, заключается в добавлении строки, которая просто делает [224,224,3] к [?, 224,224,3]. Тем не менее, я обеспокоен тем, что это приведет к выводу запущенного итератора из [224,224,3] в [?, 224,224,3]. Есть ли такой умный способ сделать это?

Спасибо за помощь!


В настоящее время мой набор данных преобразован в TFRecords:

val / validation-00000-of-00128 to val / validation-00128-of-00128

Мой текущий код такой.

def decode_for_eval(example):
    image = image_ops.decode_jpeg(example, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.central_crop(image, central_fraction=0.875)
    image = tf.expand_dims(image, 0)
    image = tf.image.resize_bilinear(image, [224, 224], align_corners=False)
    image = tf.squeeze(image, [0])
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    return image

def decode(serialized_example):
    features = tf.parse_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string, default_value=''),
            'image/format': tf.FixedLenFeature([], tf.string, default_value='jpeg'),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=-1),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
        }
    )
    width = tf.cast(features['image/width'], tf.int32)
    height = tf.cast(features['image/height'], tf.int32)

    ### EVAL ###
    image = tf.map_fn(lambda x: decode_for_eval(x), features['image/encoded'], dtype=tf.float32)

    ### TRAIN ###
    # I want to build this!

    label = tf.cast(features['image/class/label'], tf.int32)

    return image, label

def get_iterator(filenames, batch_size):
    with tf.name_scope('input'):
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.repeat(None)
        dataset = dataset.shuffle(5000)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(decode, num_parallel_calls=4)
        dataset = dataset.prefetch(batch_size)
        iterator = dataset.make_initializable_iterator()
    return iterator

def main():
    sess = tf.Session()

    # get iterator for single file
    dataset_dir = 'val/validation-*'
    filenames = glob.glob(dataset_dir)
    iterator = get_iterator(filenames, 128)
    sess.run(iterator.initializer)
    image, label = iterator.get_next()

    images, labels = sess.run([image, label])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...