Невозможно прочитать прогнозы после вызова estimator.predict в тензорном потоке - PullRequest
0 голосов
/ 15 января 2019

Я новичок в tenorflow и пытаюсь использовать следующий пример кода для запуска функции прогнозирования.

https://www.tensorflow.org/tutorials/sequences/recurrent_quickdraw#loss_predictions_and_optimizer

с тензорным потоком 1,12

Я поместил следующую строку в model_fn (функции, метки, режим, параметры):

# Set None value for predict mode 
if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = tf.argmax(logits, axis=1)
    return tf.estimator.EstimatorSpec(
         mode=mode,
         predictions={"logits": logits, "predictions": predictions})

вызовите следующую функцию, чтобы запустить прогноз

def run_predict(self, estimator):
result = estimator.predict(
        input_fn=get_input_fn(mode=tf.estimator.ModeKeys.PREDICT,
        batch_size= g_config['batch_size']))

      print(next(result))
      return result 

def predict(self):
    estimator = self.create_estimator(
          run_config = tf.estimator.RunConfig(
          model_dir = self.model_path))

    result = self.run_predict(estimator)

    print(next(result))

но я получил эти ошибки

Traceback (most recent call last):
File "learn_dat.py", line 116, in <module>
main(sys.argv)
File "learn_dat.py", line 105, in main
rnn.predict_data(league, season, date_str)
File "/root/work/breg_new/lib/prdc/rnn_model.py", line 438, in predict_data
model.predict()
File "/root/work/breg_new/lib/prdc/rnn_model.py", line 141, in predict
result = self.run_predict(estimator)
File "/root/work/breg_new/lib/prdc/rnn_model.py", line 105, in run_predict
print(next(result))
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 549, in predict
input_fn, model_fn_lib.ModeKeys.PREDICT)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 1024, in _get_features_from_input_fn
result = self._call_input_fn(input_fn, mode)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 1136, in _call_input_fn
return input_fn(**kwargs)
File "/root/work/breg_new/lib/prdc/rnn_model.py", line 257, in _input_fn
num_parallel_calls=10)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1007, in map
return ParallelMapDataset(self, map_func, num_parallel_calls)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 2248, in __init__
super(ParallelMapDataset, self).__init__(input_dataset, map_func)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 2216, in __init__
map_func, "Dataset.map()", input_dataset)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1473, in __init__
self._function.add_to_graph(ops.get_default_graph())
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 479, in add_to_graph
self._create_definition_if_needed()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 335, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 344, in _create_definition_if_needed_impl
self._capture_by_value, self._caller_device)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 865, in func_graph_from_py_func
outputs = func(*func_graph.inputs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1456, in tf_data_structured_function_wrapper
"%s: %s." % (transformation_name, t))
TypeError: Unsupported return value from function passed to Dataset.map(): None.
...