Набор данных tenorflow from_generator () выходит за пределы допустимого диапазона - PullRequest
1 голос
/ 08 апреля 2020

Я пытаюсь использовать tf.data.Dataset.from_generator() для генерации данных обучения и проверки.

У меня есть собственный генератор данных, который выполняет подготовку функций на лету:

def data_iterator(self, input_file_list, ...):
    for f in input_file_list:
        X, y = get_feature(f)
        yield X, y

Первоначально я кормил это напрямую к модели tenorflow keras, но после первой партии я столкнулся с ошибкой данных вне диапазона. Тогда я решил обернуть это в генератор данных tenorflow:

train_gen = lambda: data_iterator(train_files, ...)
valid_gen = lambda: data_iterator(valid_files, ...)

output_types = (tf.float32, tf.float32)
output_shapes = (tf.TensorShape([499, 13]), tf.TensorShape([2]))
train_dat = tf.data.Dataset.from_generator(train_gen,
                                           output_types=output_types,
                                           output_shapes=output_shapes)
valid_dat = tf.data.Dataset.from_generator(valid_gen,
                                           output_types=output_types,
                                           output_shapes=output_shapes)
train_dat = train_dat.repeat().batch(batch_size=128)
valid_dat = valid_dat.repeat().batch(batch_size=128)

Затем подойдет:

model.fit(x=train_dat,
          validation_data=valid_dat,
          steps_per_epoch=train_steps,
          validation_steps=valid_steps,
          epochs=100,
          callbacks=callbacks)

Однако, я все еще получаю ошибку, несмотря на наличие .repeat() в генераторе:

BaseCollectiveExecutor :: StartAbort Вне диапазона: конец последовательности

Мой вопрос:

  • почему .repeat() здесь не работает ?
  • я должен добавить while True в свой собственный итератор, чтобы избежать этого? Я чувствую, что это может исправить это, но это не похоже на правильный способ сделать это.

1 Ответ

1 голос
/ 09 апреля 2020

Я добавил некоторое время True в свой собственный генератор, чтобы он никогда не заканчивался, и я больше не получаю ошибку:

def data_iterator(self, input_file_list, ...):
    while True;
        for f in input_file_list:
            X, y = get_feature(f)
            yield X, y

Однако я не знаю, почему .repeat() не работает на .from_generator()

...