как избавиться от итератора набора данных tenorflow и функции предварительной обработки из графа тензорного потока - PullRequest
0 голосов
/ 06 февраля 2019

Я перешел по ссылке Тонкая настройка VGG16 , чтобы обучить mobilenetV1, обучение выглядит хорошо.Затем я заморозил график и посмотрел на график на тензорной доске. Я обнаружил, что график будет включать в себя итератор набора данных, который используется для итерации по набору обучающих данных, а также функции предварительной обработки для обучения и проверки.

Чтобы использовать этозамороженный граф для вывода, мне не нужна обучающая функция предварительной обработки и итератор набора данных, но, поскольку они являются тензорными операциями, я не знаю, как удалить их из замороженного графа, кто-нибудь может помочь?Или я не должен использовать итераторы tf.data.Dataset для загрузки пакета данных, а вместо этого использовать массивы numpy?

graph = tf.Graph()
with graph.as_default():

        train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
    train_dataset = train_dataset.map(_parse_function,
        num_parallel_calls=args.num_workers).prefetch(args.batch_size)
    train_dataset = train_dataset.map(training_preprocess,
        num_parallel_calls=args.num_workers).prefetch(args.batch_size)
    train_dataset = train_dataset.shuffle(buffer_size=10000)  # don't forget to shuffle
    batched_train_dataset = train_dataset.batch(args.batch_size)

    # Validation dataset
    val_dataset = tf.data.Dataset.from_tensor_slices((val_filenames, val_labels))
    val_dataset = val_dataset.map(_parse_function,
        num_parallel_calls=args.num_workers).prefetch(args.batch_size)
    val_dataset = val_dataset.map(val_preprocess,
        num_parallel_calls=args.num_workers).prefetch(args.batch_size)
    batched_val_dataset = val_dataset.batch(args.batch_size)

    iterator = tf.data.Iterator.from_structure(batched_train_dataset.output_types,
                                                       batched_train_dataset.output_shapes)
    images, labels = iterator.get_next()

Это изображение показывает график высокого уровня, одной из структур данных справа являетсяобучающая функция предварительной обработки, итератор является входом для модели мобильной сети.

...