Ошибка Dataset.from_generator, когда X и метки являются кортежами - PullRequest
0 голосов
/ 18 февраля 2020

Я пытаюсь создать tf.data.Dataset из X, labels из функции генератора в ordet, чтобы передать его в tf.keras.Model.fit.

Моя модель имеет несколько входов и выходов (все изображения / маски, подобные массивы). Упрощенная версия генератора выглядит следующим образом:

def generator():
    sample = get_single_sample()
    while True:
       yield ((sample[0], sample[1]), (sample[2], sample[3]))


dataset = tf.data.Dataset.from_generator(gen, output_types=((tf.float32,)*2,)*2)

Как упоминалось ранее, массив sample содержит 4 массива с плавающей точкой numpy.

Я не могу этого сделать работа:

  • как должна выглядеть output_types? Я пытался output_types=((tf.float32,)*2,)*2 и получаю в результате ошибку: ValueError: as_list() is not defined on an unknown TensorShape
  • я должен указать output_shapes в моем случае? ( документация говорит, что это необязательно)
...