Я пытаюсь использовать tf.data.Dataset.from_generator()
для генерации данных обучения и проверки.
У меня есть собственный генератор данных, который выполняет подготовку функций на лету:
def data_iterator(self, input_file_list, ...):
for f in input_file_list:
X, y = get_feature(f)
yield X, y
Первоначально я кормил это напрямую к модели tenorflow keras, но после первой партии я столкнулся с ошибкой данных вне диапазона. Тогда я решил обернуть это в генератор данных tenorflow:
train_gen = lambda: data_iterator(train_files, ...)
valid_gen = lambda: data_iterator(valid_files, ...)
output_types = (tf.float32, tf.float32)
output_shapes = (tf.TensorShape([499, 13]), tf.TensorShape([2]))
train_dat = tf.data.Dataset.from_generator(train_gen,
output_types=output_types,
output_shapes=output_shapes)
valid_dat = tf.data.Dataset.from_generator(valid_gen,
output_types=output_types,
output_shapes=output_shapes)
train_dat = train_dat.repeat().batch(batch_size=128)
valid_dat = valid_dat.repeat().batch(batch_size=128)
Затем подойдет:
model.fit(x=train_dat,
validation_data=valid_dat,
steps_per_epoch=train_steps,
validation_steps=valid_steps,
epochs=100,
callbacks=callbacks)
Однако, я все еще получаю ошибку, несмотря на наличие .repeat()
в генераторе:
BaseCollectiveExecutor :: StartAbort Вне диапазона: конец последовательности
Мой вопрос:
- почему
.repeat()
здесь не работает ? - я должен добавить
while True
в свой собственный итератор, чтобы избежать этого? Я чувствую, что это может исправить это, но это не похоже на правильный способ сделать это.