Я совершенно новичок в Keras / Tensorflow. Ниже мой fit_generator
train_dataset=train_fn_inputs(batch_size, None)
val_data=validation_fn_inputs(batch_size, None)
total_records = 44712
val_records = 11178
steps_per_epoch=int(total_records // batch_size)
hist=model.fit_generator(#aug.flow(X_def, y_def, batch_size=batch_size),
#get_batches(X_def, y_def, batch_size),
train_dataset,
steps_per_epoch=steps_per_epoch, #(training_df.shape[0])//batchsize,
epochs=5,
verbose = 1,
#callbacks=[early_stopping],
#validation_data=val_data,
#validation_steps=val_records//batch_size,
workers=0
)
И определения функций:
def train_fn_inputs(bs, aug=None):
train_files, total_records = get_training_data_old()
steps_per_epoch = int(total_records / batch_size)
raw_dataset = tf.data.TFRecordDataset(train_files) #.repeat()
parsed_image_dataset = raw_dataset.map(_parse_image_function).repeat().shuffle(buffer_size=buf_size).batch(batch_size).make_initializable_iterator()
image, label = parsed_image_dataset.get_next()
image = tf.reshape(image, [3, IMG_WIDTH, IMG_HEIGHT, bs])
#label = tf.reshape(label, [bs, 75, 25])
while True:
yield (np.array(image), np.array(label))
Но я получаю эту ошибку:
Файл "... \ Anaconda3\ lib \ site-packages \ tenorflow \ python \ keras \ engine \ training_generator.py ", строка 184, в model_iteration batch_size = int (nest.flatten (batch_data) [0] .shape [0])
IndexError: индекс кортежа вне диапазона