Правильный способ перебора tf.data.Dataset в сессии для 2.0 - PullRequest
0 голосов
/ 31 мая 2019

Я скачал некоторые данные *.tfrecord из проекта youtube-8m .Вы можете загрузить «небольшую» часть данных с помощью этой команды:

curl data.yt8m.org/download.py | shard=1,100 partition=2/video/train mirror=us python

Я пытаюсь понять, как использовать новый API tf.data.Я хотел бы познакомиться с типичными способами, которыми люди перебирают наборы данных.Я использовал руководство на веб-сайте TF и ​​этот слайд: Слайды Дерека Мюррея

Вот как я определяю набор данных:

# Use interleave() and prefetch() to read many files concurrently.
files = tf.data.Dataset.list_files("./youtube_vids/*.tfrecord")
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100),
                           cycle_length=8)

# Use num_parallel_calls to parallelize map().
dataset = dataset.map(lambda record: tf.parse_single_example(record, feature_map),
                     num_parallel_calls=2) #

# put in x,y output form
dataset = dataset.map(lambda x: (x['mean_rgb'], x['id']))

# shuffle
dataset = dataset.shuffle(10000)

#one epoch
dataset = dataset.repeat(1)
dataset = dataset.batch(200)

#Use prefetch() to overlap the producer and consumer.
dataset = dataset.prefetch(10)

Теперь я знаю вВ нетерпеливом режиме выполнения я могу просто

for x,y in dataset:
    x,y

Однако, когда я пытаюсь создать итератор следующим образом:

# A one-shot iterator automatically initializes itself on first use.
iterator = dset.make_one_shot_iterator()

# The return value of get_next() matches the dataset element type.
images, labels = iterator.get_next()

И запустить с сеансом

with tf.Session() as sess:

    # Loop until all elements have been consumed.
    try:
        while True:
            r = sess.run(images)
    except tf.errors.OutOfRangeError:
        pass

Iполучите предупреждение

Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.

Итак, вот мой вопрос:

Как правильно перебрать набор данных в сеансе?Это просто вопрос различий между v1 и v2?

Кроме того, совет передать набор данных непосредственно в оценщик подразумевает, что у входной функции также есть итератор, определенный как на слайдах Дерека Мюррея выше, верно?

1 Ответ

2 голосов
/ 01 июня 2019

Что касается Estimator API, нет необходимости указывать итератор, просто передайте объект набора данных в качестве входной функции.

def input_fn(filename):
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.shuffle().repeat()
    dataset = dataset.map(parse_func)
    dataset = dataset.batch()
    return dataset

estimator.train(input_fn=lambda: input_fn())

В TF 2.0 набор данных стал итеративным, поэтому, как сказано в предупреждающем сообщении, вы можете использовать

for x,y in dataset:
    x,y
...