Вы можете использовать функцию Dataset.cache
с функцией Dataset.take
для достижения этой цели.
Если все умещается в памяти, это так же просто, как сделать что-то вроде этого:
def generate_example():
i = 0
while(True):
print ('yielding value {}'.format(i))
yield tf.random.uniform((64,64,3))
i +=1
ds = tf.data.Dataset.from_generator(generate_example, tf.float32)
first_n_datapoints = ds.take(n).cache()
Теперь обратите внимание, что если я установлю n
на 3, скажем, сделайте что-нибудь тривиальное, например:
for i in first_n_datapoints.repeat():
print ('')
print (i.shape)
, тогда я увижу вывод, подтверждающий, что первые 3 значения кэшируются (я вижу вывод yielding value {i}
только один раздля каждого из первых 3 сгенерированных значений:
yielding value 0
(64,64,3)
yielding value 1
(64,64,3)
yielding value 2
(64,64,3)
(64,64,3)
(64,64,3)
(64,64,3)
...
Если все не умещается в памяти, мы можем передать путь к файлу в функцию кеширования, где она будет кешировать сгенерированные тензоры на диск.
Больше информации здесь: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache