API набора данных Tensorflow: распараллеливание tf.data.Dataset.from_generator с параллельным_интерлейвом - PullRequest
0 голосов
/ 11 мая 2018

В производственной среде у меня есть данные, поступающие от N производителей, которые должны пройти через сеть.Я нашел этот комментарий о распараллеливании tf.data.Dataset.from_generator , который действительно описывает то, что я хочу.

def generator(n):
  # returns n-th generator function

def dataset(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))

# where N is the number of generators you use

Однако как должна выглядеть функция generator (n).Потому что, когда я запускаю этот пример с

 def generator(n):
        """Returns the n-th generator function (for consumer n)
        """
        consumer = self.consumers[n]

        def gen():
            for item in consumer:
                yield item

        return gen

с self.consumers списком Python, я получаю ошибку:

TypeError: индексы списка должны быть целыми или кусочками, а неТензор

1 Ответ

0 голосов
/ 23 апреля 2019

Реализация почти правильная, но вы получаете ошибку, потому что аргумент n в dataset(n) является "символическим" tf.Tensor, а не фактическим значением, которое можно использоватьпоиск потребителя в self.consumers.

К счастью, существует обходной путь, который включает передачу n через необязательный аргумент args в tf.data.Dataset.from_generator():

def dataset(n):
  return tf.data.Dataset.from_generator(generator, args=(n,))

Под прикрытием from_generator() вставляет некоторый код для преобразования n в целое число Python перед каждым вызовом generator.

...