Я пытался использовать InMemoryEvaluatorHook с TPUEstimator, чтобы получать метрики на проверочном наборе во время обучения модели. Цикл из вызовов estimator.train() и estimator.evaluate() оказался слишком дорогим, поскольку граф перестраивался каждую эпоху вместо того, чтобы использоваться повторно (как описано в этом issue: https://github.com/tensorflow/tensorflow/issues/13895). Вот основной код, который я использую:
# Excerpt from a method (it reads `self.*`); model_fn, run_config, input_fn
# and tf are defined outside this snippet.
# TPUEstimator with train/eval/predict batch sizes all taken from
# self.batch_size.
estimator = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
config=run_config,
use_tpu=True,
train_batch_size=self.batch_size,
eval_batch_size=self.batch_size,
predict_batch_size=self.batch_size,
params={})
# Input functions receive `params` and read the batch size from
# params['batch_size'] (presumably filled in by TPUEstimator per mode --
# confirm against the TPUEstimator docs).
train_fn = lambda params: input_fn(
'train', self.data_dir, batch_size=params['batch_size'], train=True)
val_fn = lambda params: input_fn(
'validation',
self.data_dir,
batch_size=params['batch_size'],
train=False)
# In-training evaluation hook: runs val_fn for steps_per_val_epoch steps,
# with every_n_iter=steps_per_epoch (intended: once per epoch).
# NOTE(review): per the first traceback, the hook's begin() builds the eval
# graph while estimator.train() is running, and TPUEstimator's _model_fn then
# fails with KeyError: 'eval' on self._rendezvous[mode] -- this hook, used as
# shown, does not work with TPUEstimator in this TF version.
train_hook = tf.contrib.estimator.InMemoryEvaluatorHook(
estimator,
val_fn,
steps=self.steps_per_val_epoch,
every_n_iter=self.steps_per_epoch)
# Train for steps_per_epoch * max_num_training_epochs steps in total,
# with the evaluation hook attached.
estimator.train(
input_fn=train_fn,
steps=self.steps_per_epoch * self.max_num_training_epochs,
hooks=[
train_hook,
])
Это привело к следующей ошибке:
Traceback (most recent call last):
File "dev/google_communicator/worker.py", line 160, in <module>
main()
File "dev/google_communicator/worker.py", line 133, in main
results = evaluator.eval(inputs, outputs)
File "/darch/deep_architect/contrib/misc/evaluators/tensorflow/tpu_estimator_classification.py", line 278, in eval
train_hook,
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
rendezvous.raise_errors()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
six.reraise(typ, value, traceback)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
saving_listeners=saving_listeners
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
saving_listeners)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
log_step_count_steps=log_step_count_steps) as mon_sess:
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
stop_grace_period_secs=stop_grace_period_secs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
h.begin()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/estimator/python/estimator/hooks.py", line 135, in begin
self._input_fn, self._hooks, checkpoint_path=None)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1484, in _evaluate_build_graph
self._call_model_fn_eval(input_fn, self.config))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1520, in _call_model_fn_eval
features, labels, model_fn_lib.ModeKeys.EVAL, config)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2195, in _call_model_fn
features, labels, mode, config)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1195, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2631, in _model_fn
rendezvous=self._rendezvous[mode]),
KeyError: 'eval'
Есть ли лучший способ получать метрики на проверочном наборе после каждой эпохи при обучении на TPU? Если нет, то как правильно выполнять валидацию?
Правка: похоже, мне удалось обойти эту ошибку, сначала выполнив estimator.train() и estimator.evaluate() на один шаг без хука, а затем запустив полное обучение с хуком. К сожалению, после первой оценки при возобновлении обучения возникает ошибка:
Traceback (most recent call last):
File "dev/google_communicator/worker.py", line 160, in <module>
main()
File "dev/google_communicator/worker.py", line 133, in main
results = evaluator.eval(inputs, outputs)
File "/darch/deep_architect/contrib/misc/evaluators/tensorflow/tpu_estimator_classification.py", line 329, in eval
train_hook,
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
rendezvous.raise_errors()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
six.reraise(typ, value, traceback)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
saving_listeners=saving_listeners
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 354, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
saving_listeners)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 1471, in _train_with_estimator_spec
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 671, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1156, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1255, in run
raise six.reraise(*original_exc_info)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1240, in run
return self._sess.run(*args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1312, in run
run_metadata=run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 1076, in run
return self._sess.run(*args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 929, in run
run_metadata_ptr)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1152, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1328, in _do_run
run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1348, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: The TPU system has not been initialized.
[[{{node TPUReplicate/_compile/_14248540389241865347/_28}} = TPUCompile[NumDynamicShapes=0, Tguaranteed_constants=[], function=cluster_18378946049549366873_f15n_0[], metadata="\n\006\010...6\323\352L", num_computations=1, _device="/job:worker/replica:0/task:0/device:CPU:0"](^cluster/control_before/_0)]]
[[{{node tpu_compile_succeeded_assert/_1897752282630996029/_29_G679}} = _Recv[client_terminated=false, recv_device="/job:worker/replica:0/task:0/device:TPU:2", send_device="/job:worker/replica:0/task:0/device:CPU:0", send_device_incarnation=2337451129362726278, tensor_name="edge_174_tpu_compile_succeeded_assert/_1897752282630996029/_29", tensor_type=DT_FLOAT, _device="/job:worker/replica:0/task:0/device:TPU:2"]()]]
Уточню: до появления ошибки происходит следующее: два инициализирующих вызова train и evaluate у оценщика, обучение в течение одной эпохи и оценка на проверочном наборе. Исключение выбрасывается, когда оценщик пытается возобновить обучение.
Возможно, с этим связана следующая открытая проблема: https://github.com/tensorflow/tensor2tensor/issues/1202