tf.estimator.BoostedTreesRegressor SavedModel Проблема восстановления - PullRequest
0 голосов
/ 20 июня 2019

У меня проблема с восстановлением модели tf.estimator.BoostedTreesRegressor с использованием tf.SavedModel.При перезагрузке модели из каталога сохраненной модели с помощью tf.contrib.predictor.from_saved_model () я получаю следующую ошибку:

KeyError: «Имя« boosted_trees / QuantileAccumulator / »относится к операции, нена графике. "

Эта ошибка возникает только при использовании числовых функций (например, tf.feature_column.numeric_column).Перезагрузка модели работает нормально при использовании только категориальных столбцов

Когда я не сохраняю / не восстанавливаю, BoostedTreesRegressor оценивает и прогнозирует успешно со всеми функциями.

Следующие сценарии оценки / восстановления успешно работали:
- DNNRegressor с числовыми и категориальными функциями
- LinearRegressor с числовыми и категориальными функциями
- BoostedTreeRegressor с только категориальными функциями

fc = tf.feature_column
feature_columns = [
fc.numeric_column('f1', dtype=tf.int64),
fc.numeric_column('f2', dtype=tf.int64),
fc.indicator_column(
               fc.categorical_column_with_vocabulary_list('f3',f3)),
fc.indicator_column(
               fc.categorical_column_with_vocabulary_list('f4',f4))
]

feature_spec = fc.make_parse_example_spec(feature_columns)

params = {
    'feature_columns' : feature_columns,
    'n_batches_per_layer' : n_batches,
    'n_trees': 200,
    'max_depth': 6,
    'learning_rate': 0.01
}

regressor = tf.estimator.BoostedTreesRegressor(**params)
regressor.train(train_input_fn, max_steps=400)

serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

regressor.export_saved_model('saved_model', serving_input_receiver_fn)

.
.
.
# latest is path to saved model
predict_fn = predictor.from_saved_model(latest[:-4])
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-101-ee20beae4424> in <module>
----> 1 predict_fn = predictor.from_saved_model(latest[:-4])
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/contrib/predictor/predictor_factories.py in from_saved_model(export_dir, signature_def_key, signature_def, input_names, output_names, tags, graph, config)
    151       tags=tags,
    152       graph=graph,
--> 153       config=config)
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/contrib/predictor/saved_model_predictor.py in __init__(self, export_dir, signature_def_key, signature_def, input_names, output_names, tags, graph, config)
    151     with self._graph.as_default():
    152       self._session = session.Session(config=config)
--> 153       loader.load(self._session, tags.split(','), export_dir)
    154 
    155     if input_names is None:
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load(sess, tags, export_dir, import_scope, **saver_kwargs)
    267   """
    268   loader = SavedModelLoader(export_dir)
--> 269   return loader.load(sess, tags, import_scope, **saver_kwargs)
    270 
    271 
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load(self, sess, tags, import_scope, **saver_kwargs)
    418     with sess.graph.as_default():
    419       saver, _ = self.load_graph(sess.graph, tags, import_scope,
--> 420                                  **saver_kwargs)
    421       self.restore_variables(sess, saver, import_scope)
    422       self.run_init_ops(sess, tags, import_scope)
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load_graph(self, graph, tags, import_scope, **saver_kwargs)
    348     with graph.as_default():
    349       return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
--> 350           meta_graph_def, import_scope=import_scope, **saver_kwargs)
    351 
    352   def restore_variables(self, sess, saver, import_scope=None):
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/training/saver.py in _import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope, return_elements, **kwargs)
   1455           import_scope=import_scope,
   1456           return_elements=return_elements,
-> 1457           **kwargs))
   1458 
   1459   saver = _create_saver_from_imported_meta_graph(
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/meta_graph.py in import_scoped_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate, return_elements)
    850           for value in field.value:
    851             col_op = graph.as_graph_element(
--> 852                 ops.prepend_name_scope(value, scope_to_prepend_to_names))
    853             graph.add_to_collection(key, col_op)
    854         elif kind == "int64_list":
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
   3476 
   3477     with self._lock:
-> 3478       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3479 
   3480   def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
   3536         if name not in self._nodes_by_name:
   3537           raise KeyError("The name %s refers to an Operation not in the "
-> 3538                          "graph." % repr(name))
   3539         return self._nodes_by_name[name]
   3540 
KeyError: "The name 'boosted_trees/QuantileAccumulator/' refers to an Operation not in the graph."
...