Как правильно использовать export_savedmodel в тензорном потоке? - PullRequest
0 голосов
/ 12 июня 2018

Я работаю с Tensorflow 1.8.

Я создал пользовательский tf.estimator my_estimator, и после обучения я хочу экспортировать свою модель, чтобы использовать ее для прогнозирования.Для этого я попытался (feature_placeholders являются входами моей модели):

feature_placeholders = {
        "Numerical_features": tf.placeholder(tf.float32, [None, None, parameters.model_params['N_INPUT']]),
        "Categorical_features": tf.placeholder(tf.int32,
                                               [None, None, len(parameters.model_params['vocabulary_sizes'].keys())]),
        "Fixed_features": tf.placeholder(tf.float32, [None]),
        "Lengths_features": tf.placeholder(tf.int32, [None]),
        "labels": tf.placeholder(tf.float32, [None]),
        "Predictions": tf.placeholder(tf.float32, [None])
                                }

my_estimator.export_savedmodel('my_directory',
        serving_input_receiver_fn=tf.estimator.export.build_raw_serving_input_receiver_fn(
        features=feature_placeholders)

Я получаю следующую ошибку:

 File "/home/train_eval_predict.py", line 650, in train_rnn
    features=feature_placeholders)
  File "/home/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 613, in export_savedmodel
    config=self.config)
  File "/home/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 831, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/home/train_eval_predict.py", line 418, in model_fn
    logits=prediction))
  File "/home/tensorflow/python/ops/nn_ops.py", line 1829, in softmax_cross_entropy_with_logits_v2
    logits)
  File "/home/tensorflow/python/ops/nn_ops.py", line 1777, in _ensure_xent_args
    raise ValueError("Both labels and logits must be provided.")
ValueError: Both labels and logits must be provided.

Как это исправить?

Кстати, я не уверен, что feature_placeholders правильно определен относительно tf.estimator.build_raw_serving_input_receiver_fn().

...