Невозможно сохранить модель при использовании стандартного оценщика Tensorflow Lattice - PullRequest
0 голосов
/ 18 июня 2020

Я пытаюсь запустить руководство Tensorflow Lattice Canned Estimators: https://www.tensorflow.org/lattice/tutorials/canned_estimators#calibrated_lattice_model

, когда доходит до этой строки:

saved_model_path = estimator.export_saved_model(estimator.model_dir,
                                                serving_input_fn)

Я получаю следующая ошибка:

>TypeError                                 Traceback (most recent call last)
<ipython-input-21-1d0ce82d7b37> in <module>
     19 print('Calibrated linear test AUC: {}'.format(results['auc']))
     20 saved_model_path = estimator.export_saved_model(estimator.model_dir,
---> 21                                                 serving_input_fn)
     22 model_graph = tfl.estimators.get_model_graph(saved_model_path)
     23 tfl.visualization.draw_model_graph(model_graph)
>
>~\Anaconda3\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py in export_saved_model(self, export_dir_base, serving_input_receiver_fn, assets_extra, as_text, checkpoint_path, experimental_mode)
    726         as_text=as_text,
    727         checkpoint_path=checkpoint_path,
--> 728         strip_default_attrs=True)
    729 
    730   def experimental_export_all_saved_models(self,
>
~\Anaconda3\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py in _export_all_saved_models(self, export_dir_base, input_receiver_fn_map, assets_extra, as_text, checkpoint_path, strip_default_attrs)
    864             save_variables,
    865             mode=ModeKeys.PREDICT,
--> 866             strip_default_attrs=strip_default_attrs)
    867         save_variables = False
    868 
>
~\Anaconda3\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py in _add_meta_graph_for_mode(self, builder, input_receiver_fn_map, checkpoint_path, save_variables, mode, export_tags, check_variables, strip_default_attrs)
    932       tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)
    933 
--> 934       input_receiver = input_receiver_fn()
    935 
    936       # Call the model_fn and collect the export_outputs.
>
~\Anaconda3\lib\site-packages\tensorflow_estimator\python\estimator\export\export.py in serving_input_receiver_fn()
    308     receiver_tensors = {'examples': serialized_tf_example}
    309     features = tf.compat.v1.io.parse_example(serialized_tf_example,
--> 310                                              feature_spec)
    311     return ServingInputReceiver(features, receiver_tensors)
    312 
>
~\Anaconda3\lib\site-packages\tensorflow\python\ops\parsing_ops.py in parse_example(serialized, features, name, example_names)
    316 @tf_export(v1=["io.parse_example", "parse_example"])
    317 def parse_example(serialized, features, name=None, example_names=None):
--> 318   return parse_example_v2(serialized, features, example_names, name)
    319 
    320 
>
~\Anaconda3\lib\site-packages\tensorflow\python\ops\parsing_ops.py in parse_example_v2(serialized, features, example_names, name)
    310   ])
    311 
--> 312   outputs = _parse_example_raw(serialized, example_names, params, name=name)
    313   return _construct_tensors_for_composite_features(features, outputs)
    314 
>
~\Anaconda3\lib\site-packages\tensorflow\python\ops\parsing_ops.py in _parse_example_raw(serialized, names, params, name)
    357         ragged_split_types=params.ragged_split_types,
    358         dense_shapes=params.dense_shapes_as_proto,
--> 359         name=name)
    360     (sparse_indices, sparse_values, sparse_shapes, dense_values,
    361      ragged_values, ragged_row_splits) = outputs
>
~\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_parsing_ops.py in parse_example_v2(serialized, names, sparse_keys, dense_keys, ragged_keys, dense_defaults, num_sparse, sparse_types, ragged_value_types, ragged_split_types, dense_shapes, name)
    761                           ragged_value_types=ragged_value_types,
    762                           ragged_split_types=ragged_split_types,
--> 763                           dense_shapes=dense_shapes, name=name)
    764   _result = _outputs[:]
    765   if _execute.must_record_gradient():
>
~\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
    693       elif attr_def.type == "list(type)":
    694         attr_value.list.type.extend(
--> 695             [_MakeType(x, attr_def) for x in value])
    696       elif attr_def.type == "shape":
    697         attr_value.shape.CopyFrom(_MakeShape(value, key))
>
~\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py in <listcomp>(.0)
    693       elif attr_def.type == "list(type)":
    694         attr_value.list.type.extend(
--> 695             [_MakeType(x, attr_def) for x in value])
    696       elif attr_def.type == "shape":
    697         attr_value.shape.CopyFrom(_MakeShape(value, key))
>
~\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py in _MakeType(v, attr_def)
    178                     (attr_def.name, repr(v)))
    179   i = v.as_datatype_enum
--> 180   _SatisfiesTypeConstraint(i, attr_def, param_name=attr_def.name)
    181   return i
    182 
>
~\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py in _SatisfiesTypeConstraint(dtype, attr_def, param_name)
     59           "allowed values: %s" %
     60           (param_name, dtypes.as_dtype(dtype).name,
---> 61            ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
     62 
     63 
>

TypeError: значение, переданное параметру sparse_types, имеет DataType int32 не в списке допустимых значений: float32, int64, string '

1 Ответ

1 голос
/ 11 июля 2020

Мне не удалось воспроизвести ошибку в среде выполнения Google Colab, предоставленную TensorFlow по умолчанию. Не могли бы вы дать ему еще один шанс?

Если вы работаете в настраиваемой среде выполнения: сообщение об ошибке предполагает, что тип данных для разреженных функций - другими словами, функций, определенных categorical_column_with_vocabulary_list - int32, который не поддерживается. Вы можете решить эту проблему, указав аргумент dtype для этих целочисленных разреженных функций, например

fc.categorical_column_with_vocabulary_list('sex', [0, 1], dtype=tf.int64)
...