Мне кажется, вы без необходимости усложняете свою жизнь с генератором.
Вот как я бы реализовал ваш входной конвейер:
def parse_file_tf(filename):
return tf.py_func(parse_file, [filename], [tf.float32, tf.float32])
# version with map
files = tf.data.Dataset.from_tensor_slices(files_to_process)
dataset = files.map(parse_file_tf, num_parallel_calls=N)
dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
dataset = dataset.batch(batch_size).shuffle(shuffle_size).prefetch(2)
it = dataset.make_one_shot_iterator()
Чтобы проверить это, я определяю пустышку parse_file
как:
i=0
def parse_file(f):
global i
i += 1
return np.asarray([i]*i, dtype=np.float32), np.asarray([i]*i, dtype=np.float32) # mimicks variable-length examples_x, examples_y
, который я передаю в базовый цикл, который показывает, что возвращает итератор:
sess = tf.Session()
try:
while True:
x, y = it.get_next()
vx, vy = sess.run([x,y])
print(vx)
print(vy)
except tf.errors.OutOfRangeError:
pass
sess.close()
Выполнение кода выше печатает:
[2. 3. 2. 1. 3. 3.]
[2. 3. 2. 1. 3. 3.]
Объяснение трубопровода
По сути, я оставляю проблему распараллеливания на map
, где я могу передать количество потоков, которые он должен запустить. Нет необходимости в генераторах, повторяющихся по диапазонам, и этим дополнительным сложностям.
Я выбрал map вместо parallel_interleave
, потому что последний требует, чтобы вы генерировали экземпляр Dataset
для каждого возвращаемого элемента, что в вашем случае не имеет смысла, поскольку вы уже загрузили все значения в память при запуске parse_file
.
parallel_interleave
имеет смысл, если вы медленно генерируете значения (например, применяя tf.data.TFRecordDataset
к списку имен файлов), но если ваш набор данных помещается в память, выберите map
.
Об ограничениях tf.py_func
, они не влияют на вашу обученную сеть, только на входной конвейер. В идеале у вас должен быть другой канал для обучения и окончательного использования сети. Вам нужно только позаботиться об ограничениях во время последнего, в то время как для обучения (если вы не делаете что-то очень специфическое с распределенным обучением и / или перемещением обучения между машинами) вы достаточно безопасны.
Версия с генератором
Если ваши JSON-файлы очень большие и их содержимое не помещается в памяти, вы можете использовать генератор, но немного отличающийся от подхода, с которого вы начали.
Идея в том, что генератор просматривает JSON-файл и yield
s по одной записи за раз. Тогда генератор должен быть вашей parse_file
функцией. В качестве примера предположим, что у вас есть следующий генератор parse_file
:
i = 3
def parse_file(filename):
global i
i += 1
ctr = 0
while ctr < i:
yield ctr, ctr
В этом случае конвейер будет выглядеть следующим образом:
def wrap_generator(filename):
return tf.data.Dataset.from_generator(parse_file(filename), [tf.int32, tf.int32])
files = tf.data.Dataset.from_tensor_slices(files_to_process)
dataset = files.apply(tf.contrib.data.parallel_interleave(wrap_generator, cycle_length=N))
dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
dataset = dataset.shuffle(shuffle_size).batch(batch_size).prefetch(2)
it = dataset.make_one_shot_iterator()
Обратите внимание, что здесь нам нужно использовать parallel_interleave
, потому что мы превращаем генераторы в Dataset
экземпляры, из которых мы извлекаем значения.
Остальное остается прежним.
Подача этого в тот же цикл выборки, что и выше:
[6. 5. 4. 4. 6. 5. 6. 6. 5. 4. 6. 4. 5. 5. 6.]
[6. 5. 4. 4. 6. 5. 6. 6. 5. 4. 6. 4. 5. 5. 6.]