Функция вызова узла метрик утечки памяти в пользовательском оценщике TensorFlow - PullRequest
0 голосов
/ 22 ноября 2018

Я использую следующую функцию для расчета дополнительных показателей для моего обучения.Я создаю host_call с: host_call = (host_call_fn, metric_args) и передаю его в аргумент host_call оценщика.Однако вызов этого приводит к утечке памяти, и я не могу понять, в чем проблема.При использовании кучи кажется, что большие словари каким-то образом создаются и не выпускаются.

p_temp = tf.reshape(policy_loss, [1], name='policy_loss_reshape')
v_temp = tf.reshape(value_loss, [1], name='value_loss_reshape')
e_temp = tf.reshape(entropy_loss, [1], name='entropy_loss_reshape')
t_temp = tf.reshape(total_loss, [1], name='total_loss_reshape')
g_temp = tf.reshape(global_step, [1], name='global_step_reshape')
#
metric_args = [p_temp, v_temp, e_temp, t_temp, g_temp]

host_call_fn = functools.partial(
  eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.TRAIN)
host_call = (host_call_fn, metric_args)

Следующая функция вычисляет дополнительные метрики оценки и записывает их в сводный каталог для Tensorboard.

def eval_metrics_host_call_fn(p_temp,
                            v_temp,
                            e_temp,
                            t_temp,
                            step,
                            est_mode=tf.estimator.ModeKeys.TRAIN):
#
with tf.variable_scope('metrics'):
  metric_ops = {
      'policy_loss': tf.metrics.mean(p_temp, name='policy_loss_metric'),
      'value_loss': tf.metrics.mean(v_temp, name='value_loss_metric'),
      'entropy_loss': tf.metrics.mean(e_temp, name='entropy_loss_metric'),
      'total_loss': tf.metrics.mean(t_temp, name='total_loss_metric')
  }
if est_mode == tf.estimator.ModeKeys.EVAL:
  return metric_ops
eval_step = tf.reduce_min(step)
# Create summary ops so that they show up in SUMMARIES collection
# That way, they get logged automatically during training
summary_writer = summary.create_file_writer(FLAGS.summary_dir)
with summary_writer.as_default(
), summary.record_summaries_every_n_global_steps(FLAGS.summary_steps,
                                                 eval_step):
  for metric_name, metric_op in metric_ops.items():
    summary.scalar(metric_name, metric_op[1], step=eval_step)
# Reset metrics occasionally so that they are mean of recent batches.
reset_op = tf.variables_initializer(tf.local_variables('metrics'))
cond_reset_op = tf.cond(
    tf.equal(eval_step % FLAGS.summary_steps, tf.to_int64(1)),
    lambda: reset_op, lambda: tf.no_op())

return summary.all_summary_ops() + [cond_reset_op]
...