Я тренируюсь с использованием тензорного потока tfrecords, вот методы / функции, которые обрабатывают данные:
функция чтения:
def read_tfr(tf_record_file, classes_file, feature_map, max_boxes,
classes_delimiter='\n', new_size=None):
text_init = tf.lookup.TextFileInitializer(
classes_file, tf.string, 0, tf.int64, -1, delimiter=classes_delimiter)
class_table = tf.lookup.StaticHashTable(text_init, -1)
files = tf.data.Dataset.list_files(tf_record_file)
dataset = files.flat_map(tf.data.TFRecordDataset)
default_logger.info(f'Read TFRecord: {tf_record_file}')
return dataset.map(
lambda x: read_example(x, feature_map, class_table, max_boxes, new_size))
набор данных:
def initialize_dataset(self, tf_record, batch_size):
dataset = read_tfr(tf_record, self.classes_file, get_feature_map(), self.max_boxes)
dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.map(lambda x, y: (
transform_images(x, self.input_shape[0]),
transform_targets(y, self.anchors, self.masks, self.input_shape[0])))
dataset = dataset.prefetch(
buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
затем я звоню model.fit()
следующим образом:
history = self.training_model.fit(training_dataset,
epochs=epochs,
callbacks=callbacks,
validation_data=valid_dataset)
Это прекрасно работает для небольших наборов данных, однако, когда я тренируюсь на больших наборах данных, обучение рано или поздно прекращается убийцей оом. поэтому я подумал об использовании генератора вместо этого ... вот что я попробовал:
После изменения метода:
def initialize_dataset(self, tf_record, batch_size):
dataset = read_tfr(tf_record, self.classes_file, get_feature_map(), self.max_boxes)
dataset = dataset.repeat()
dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.map(lambda x, y: (
transform_images(x, self.input_shape[0]),
transform_targets(y, self.anchors, self.masks, self.input_shape[0])))
dataset = dataset.prefetch(
buffer_size=tf.data.experimental.AUTOTUNE)
image, label = next(iter(dataset.take(1)))
while True:
yield image, label
и я вызываю model.fit()
примерно так:
training_size = sum([1 for _ in tf.data.TFRecordDataset(self.train_tf_record)])
history = self.training_model.fit(training_gen,
epochs=epochs,
callbacks=callbacks,
validation_data=valid_gen,
steps_per_epoch=training_size / batch_size)
Тренировка останавливается после первой эпохи:
2/45 [>.............................] - ETA: 23s - loss: 9690.3174 - layer_207_output_0_loss: 470.6727 - layer_234_output_1_loss: 1492.7981 - layer_261_output_2_loss: 7715.1030WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time. Check your callbacks.
46/45 [==============================] - ETA: 0s - loss: 4130.6221 - layer_207_output_0_loss: 95.8419 - layer_234_output_1_loss: 458.6847 - layer_261_output_2_loss: 3564.34792020-04-28 04:37:58,844 session_log.read_tfr +177: INFO [7901] Read TFRecord: ../Data/TFRecords/beverly_hills_test.tfrecord
Если я не укажу steps_per_epoch
, эпоха будет проходить после 45 шагов и идти навсегда (до 2000 шагов в одно из испытаний без ошибок)
Итак, я попытался заменить while True
в генераторе на while start < num_of_steps
, он работает, и после первой эпохи модель проверяется и сохраняется, а затем я получаю ошибку StopIteration.
Как я могу сделать эту работу с Ge nerator? Я не хочу передавать набор данных в целом в model.fit
(это работает нормально, пока что-то не выходит из памяти)