Я загружаю поток данных из генератора, используя tf.data.Dataset.from_generator
, затем увеличиваю данные и объединяю как исходные, так и дополненные данные для увеличения набора данных.Генератор вызывается для каждого связанного вызова, как его избежать?
Я написал следующий код:
import tensorflow as tf
tf.enable_eager_execution()
def generator():
# some heavy fn to load data points form stream (generator)
print("CALL generator")
for i in range(5): # simulation of data stream
print("generator iteration: ", str(i))
yield i
if __name__ == '__main__':
data = tf.data.Dataset.from_generator(generator, ( tf.int32))
dataset = data
# augment data points and add augmented version to dataset to have both original data points and augmented data
# points
data = data.concatenate(
dataset.map(lambda x: x * 2))
iterator = data.batch(10).prefetch(10).make_one_shot_iterator()
for im in iterator:
print(im)
Я ожидаю:
CALL generator
generator iteration: 0
generator iteration: 1
generator iteration: 2
generator iteration: 3
generator iteration: 4
tf.Tensor([0 1 2 3 4 0 2 4 6 8], shape=(10,), dtype=int32)
но я получаю:
CALL generator
generator iteration: 0
generator iteration: 1
generator iteration: 2
generator iteration: 3
generator iteration: 4
CALL generator
generator iteration: 0
generator iteration: 1
generator iteration: 2
generator iteration: 3
generator iteration: 4
tf.Tensor([0 1 2 3 4 0 2 4 6 8], shape=(10,), dtype=int32)