Как создать фиксированную длину tf.Dataset из генератора? - PullRequest
1 голос
/ 09 июня 2019

У меня есть генератор, который выдает бесконечное количество данных (Случайное изображение кадрирования). Я хотел бы создать tf.Dataset на основе, скажем, 10000 первых точек данных и кэшировать его, чтобы использовать их для обучения моделей?

В настоящее время у меня есть генератор, который занимает 1-2 секунды, чтобы создать каждую точку данных, и это основной блокатор производительности. Мне нужно подождать минуту, чтобы сгенерировать пакет из 64 изображений (функция preprocessing() очень дорогая, поэтому я хотел бы повторно использовать результаты).

ds = tf.Dataset.from_generator() метод позволяет нам создать такой бесконечный набор данных. Вместо этого я хотел бы создать конечный набор данных, используя N первых выходов генератора, и кэшировать его следующим образом:

ds = ds.cache().


Альтернативное решение - продолжать генерировать новые данные и использовать кэшированные точки данных при рендеринге генератора.

1 Ответ

1 голос
/ 09 июня 2019

Вы можете использовать функцию Dataset.cache с функцией Dataset.take для достижения этой цели.

Если все умещается в памяти, это так же просто, как сделать что-то вроде этого:

def generate_example():
  i = 0
  while(True):
    print ('yielding value {}'.format(i))
    yield tf.random.uniform((64,64,3))
    i +=1

ds = tf.data.Dataset.from_generator(generate_example, tf.float32)

first_n_datapoints = ds.take(n).cache()

Теперь обратите внимание, что если я установлю n на 3, скажем, сделайте что-нибудь тривиальное, например:

for i in first_n_datapoints.repeat():
  print ('')
  print (i.shape)

, тогда я увижу вывод, подтверждающий, что первые 3 значения кэшируются (я вижу вывод yielding value {i} только один раздля каждого из первых 3 сгенерированных значений:

yielding value 0
(64,64,3)
yielding value 1
(64,64,3)
yielding value 2
(64,64,3)
(64,64,3)
(64,64,3)
(64,64,3)
...

Если все не умещается в памяти, мы можем передать путь к файлу в функцию кеширования, где она будет кешировать сгенерированные тензоры на диск.

Больше информации здесь: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache

...