TensorFlow: многопоточная распаковка наборов данных - PullRequest
0 голосов
/ 11 июля 2019

Я использую TensorFlow 2.0 beta.У меня есть TensorFlow Dataset, где каждый элемент представляет собой набор столбцов объектов: кортеж тензоров, каждый из которых имеет значения определенной функции для batch_size записей.Мне нужно сгладить эти записи для сериализации как TFRecords, что я хотел бы сделать, используя функции TensorFlow Dataset.Сглаженные записи не нужно создавать в детерминированном порядке.

Вот пример кода, демонстрирующего то, что я пытаюсь выполнить:

batch_size = 100
num_batches = 10
input_data = (tf.constant(['text_data']), tf.constant(13))
ds = tf.data.Dataset.from_tensors(input_data).repeat(batch_size * num_batches)
ds = ds.batch(batch_size)
# ds = ... (multithreaded data transformations on batches of records happen here)
ds = ds.unbatch()

Проблема в том, что методы, которые япытались сделать это либо не работают, либо образуют серьезные узкие места, потому что они являются однопоточными.Вот некоторые из этих подходов:

  1. unbatch - однопоточный, слишком медленный
  2. interleave / flat_map - flat_map не принимает на кортежитензор - "принимает 2 позиционных аргумента, но задано" [num_features] ""
  3. interleave / пользовательская функция с py_function - не работает, потому что py_function не может вернуть Dataset
  4. interleave / пользовательская функция без py_function - не работает, потому что в графическом режиме не может перебирать тензоры

Мне нужно каким-то образом заменить unbatchраспределения пакетов по нескольким потокам, которые независимо друг от друга распаковывают, а затем чередуют результаты из разных потоков.Есть идеи?

1 Ответ

0 голосов
/ 11 июля 2019

Вот версия, которую я в конечном итоге нашел, используя interleave с from_tensor_slices:

batch_size = 100
num_batches = 10
num_threads = 4
input_data = (tf.constant(['text_data']), tf.constant(13))
ds = tf.data.Dataset.from_tensors(input_data).repeat(batch_size * num_batches)
ds = ds.batch(batch_size)
# ds = ... (multithreaded data transformations on batches of records happen here)
ds = ds.interleave(lambda *args:tf.data.Dataset.from_tensor_slices(args), num_threads, 1, num_threads)
...