Я пытаюсь создать tf.data.Dataset
из X, labels
из функции генератора в ordet, чтобы передать его в tf.keras.Model.fit
.
Моя модель имеет несколько входов и выходов (все изображения / маски, подобные массивы). Упрощенная версия генератора выглядит следующим образом:
def generator():
sample = get_single_sample()
while True:
yield ((sample[0], sample[1]), (sample[2], sample[3]))
dataset = tf.data.Dataset.from_generator(gen, output_types=((tf.float32,)*2,)*2)
Как упоминалось ранее, массив sample
содержит 4 массива с плавающей точкой numpy.
Я не могу этого сделать работа:
- как должна выглядеть
output_types
? Я пытался output_types=((tf.float32,)*2,)*2
и получаю в результате ошибку: ValueError: as_list() is not defined on an unknown TensorShape
- я должен указать
output_shapes
в моем случае? ( документация говорит, что это необязательно)