Проблема с компонентом TFX Trainer, не выводящим модель в файловую систему - PullRequest
0 голосов
/ 09 апреля 2020

Прежде всего, я использую TFX версии 0.21.2 и Tensorflow версии 2.1.

Я построил конвейер в значительной степени по примеру такси Чигако. Когда компонент Trainer выполнен, я вижу в журналах следующее:

INFO - Обучение завершено. Модель записана в / root / airflow / tfx / pipelines / fish / Trainer / model / 9 / serve_model_dir

При проверке вышеуказанного каталога он пуст. Чего мне не хватает?

Это мой файл определения DAG (операторы импорта опущены):

_pipeline_name = 'fish'
_airflow_config = AirflowPipelineConfig(airflow_dag_config = {
    'schedule_interval': None,
    'start_date': datetime.datetime(2019, 1, 1),
})
_project_root = os.path.join(os.environ['HOME'], 'airflow')
_data_root = os.path.join(_project_root, 'data', 'fish_data')
_module_file = os.path.join(_project_root, 'dags', 'fishUtils.py')
_serving_model_dir = os.path.join(_project_root, 'serving_model', _pipeline_name)
_tfx_root = os.path.join(_project_root, 'tfx')
_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
                              'metadata.db')


def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:

    examples = external_input(data_root)
    example_gen = CsvExampleGen(input=examples)

    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    infer_schema = SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=False)

    validate_stats = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=infer_schema.outputs['schema'])

    trainer = Trainer(
    examples=example_gen.outputs['examples'], schema=infer_schema.outputs['schema'],
    module_file=_module_file, train_args= trainer_pb2.TrainArgs(num_steps=10000),
    eval_args= trainer_pb2.EvalArgs(num_steps=5000))

    model_validator = ModelValidator(
      examples=example_gen.outputs['examples'],
      model=trainer.outputs['model'])

    pusher = Pusher(
      model=trainer.outputs['model'],
      model_blessing=model_validator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
          base_directory=_serving_model_dir)))

    return pipeline.Pipeline(
      pipeline_name=_pipeline_name,
      pipeline_root=_pipeline_root,
      components=[
          example_gen,
          statistics_gen,
          infer_schema,
          validate_stats,
          trainer,
          model_validator,
          pusher],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers]
  )

runner = AirflowDagRunner(config = _airflow_config)
DAG = runner.run(
    _create_pipeline(
        pipeline_name=_pipeline_name,
        pipeline_root=_pipeline_root,
        data_root=_data_root,
        module_file=_module_file,
        serving_model_dir=_serving_model_dir,
        metadata_path=_metadata_path,
        # 0 means auto-detect based on on the number of CPUs available during
        # execution time.
        direct_num_workers=0))

А это мой файл модуля:

_DENSE_FLOAT_FEATURE_KEYS = ['length']

real_valued_columns = [tf.feature_column.numeric_column('length')]

def _eval_input_receiver_fn():

  serialized_tf_example = tf.compat.v1.placeholder(
      dtype=tf.string, shape=[None], name='input_example_tensor')

  features = tf.io.parse_example(
      serialized=serialized_tf_example,
      features={
          'length': tf.io.FixedLenFeature([], tf.float32),
          'label': tf.io.FixedLenFeature([], tf.int64),
      })

  receiver_tensors = {'examples': serialized_tf_example}

  return tfma.export.EvalInputReceiver(
      features={'length' : features['length']},
      receiver_tensors=receiver_tensors,
      labels= features['label'],
      )

def parser(serialized_example):

  features = tf.io.parse_single_example(
      serialized_example,
      features={
          'length': tf.io.FixedLenFeature([], tf.float32),
          'label': tf.io.FixedLenFeature([], tf.int64),
      })
  return ({'length' : features['length']}, features['label'])

def _input_fn(filenames):
  # TFRecordDataset doesn't directly accept paths with wildcards
  filenames = tf.data.Dataset.list_files(filenames)
  dataset = tf.data.TFRecordDataset(filenames, 'GZIP')
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(2000)
  dataset = dataset.batch(40)
  dataset = dataset.repeat(10)

  return dataset

def trainer_fn(trainer_fn_args, schema):

    estimator = tf.estimator.LinearClassifier(feature_columns=real_valued_columns)

    train_input_fn = lambda: _input_fn(trainer_fn_args.train_files)

    train_spec = tf.estimator.TrainSpec(
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)

    eval_input_fn = lambda: _input_fn(trainer_fn_args.eval_files)

    eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      name='fish-eval')

    receiver_fn = lambda: _eval_input_receiver_fn()

    return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }

Спасибо заранее за вашу помощь!

1 Ответ

0 голосов
/ 15 апреля 2020

Публикация решения для любого, кто сталкивается с той же проблемой, с которой я столкнулся.

Причина, по которой модель не была записана в файловой системе, заключалась в том, что оценщику необходим аргумент конфигурации, чтобы знать, куда писать модель .

Следующая модификация функции trainer_fn должна решить проблему.

run_config = tf.estimator.RunConfig(save_checkpoints_steps=999, keep_checkpoint_max=1)  

run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)

estimator=tf.estimator.LinearClassifier(feature_columns=real_valued_columns,config=run_config)
...