Как установить контрольную точку для тонкой настройки - PullRequest
0 голосов
/ 24 мая 2019

Я обнаружил, что потеря при переобучении модели (ssd_mobilenetv2) из ​​model_zoo очень велика в начале обучения, в то время как точность validation_set хорошая. Журнал тренировок, как показано ниже:

Журнал не может быть от обученной модели. Я сомневаюсь, что это не загружает контрольную точку для точной настройки. Пожалуйста, помогите мне, как выполнить точную настройку с обученной моделью в том же наборе данных. Я вообще не изменял структуру сети.

Я установил путь к контрольной точке в pipe.config, как показано ниже: fine_tune_checkpoint: " / / ssd_mobilenet_v2_coco_2018_03_29 / model.ckpt" Если я установлю model_dir в качестве загруженного каталога, он не будет работать, так как global_train_step больше, чем max_step. Затем я увеличиваю max_step, я вижу журнал восстановления параметра из контрольной точки. Но это встретило бы ошибку, которая не могла восстановить некоторый параметр Поэтому я установил model_dir в пустой каталог. Он мог бы тренироваться нормально, но потеря в step0 была бы очень большой. И результат проверки очень плохой

в pipe.config

fine_tune_checkpoint: "/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"
num_steps: 200000
fine_tune_checkpoint_type: "detection"

сценарий поезда

model_dir = '/ssd_mobilenet_v2_coco_2018_03_29/retrain0524

pipeline_config_path = '/ssd_mobilenet_v2_coco_2018_03_29/pipeline.config'

checkpoint_dir = '/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt'

num_train_steps = 300000
config = tf.estimator.RunConfig(model_dir=model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
    run_config=config,
    hparams=model_hparams.create_hparams(hparams_overrides),
    pipeline_config_path=pipeline_config_path,    
    sample_1_of_n_eval_examples=sample_1_of_n_eval_examples,
    sample_1_of_n_eval_on_train_examples=(sample_1_of_n_eval_on_train_examples))
estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fns = train_and_eval_dict['eval_input_fns']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']

train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fns,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_on_train_data=False)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])

ИНФОРМАЦИЯ: тензор потока: потери = 356.25497, шаг = 0 ИНФОРМАЦИЯ: тензор потока: global_step / sec: 1.89768 ИНФОРМАЦИЯ: тензор потока: потери = 11,221423, шаг = 100 (52,700 с) ИНФОРМАЦИЯ: тензор потока: global_step / sec: 2.21685 ИНФОРМАЦИЯ: тензор потока: потери = 10,329516, шаг = 200 (45,109 с)

1 Ответ

0 голосов
/ 27 мая 2019

Если начальная потеря тренировки равна 400, модель, скорее всего, будет успешно восстановлена ​​с контрольной точки, но не совсем так же, как контрольная точка.

Здесь - это функция restore_map моделей ssd, обратите внимание, что даже если вы установите fine_tune_checkpoint_type : detection и даже обеспечены точно такой же контрольной точкой той же модели, все еще только переменные в feature_extractor сфера восстановлена. Чтобы восстановить как можно больше переменных из контрольной точки, вам нужно установить load_all_detection_checkpoint_vars: true в вашем конфигурационном файле.

def restore_map(self,
              fine_tune_checkpoint_type='detection',
              load_all_detection_checkpoint_vars=False):

if fine_tune_checkpoint_type not in ['detection', 'classification']:
  raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
      fine_tune_checkpoint_type))

if fine_tune_checkpoint_type == 'classification':
  return self._feature_extractor.restore_from_classification_checkpoint_fn(
      self._extract_features_scope)

if fine_tune_checkpoint_type == 'detection':
  variables_to_restore = {}
  for variable in tf.global_variables():
    var_name = variable.op.name
    if load_all_detection_checkpoint_vars:
      variables_to_restore[var_name] = variable
    else:
      if var_name.startswith(self._extract_features_scope):
        variables_to_restore[var_name] = variable

return variables_to_restore
...