Я использую следующую функцию для расчета дополнительных показателей для моего обучения.Я создаю host_call с: host_call = (host_call_fn, metric_args) и передаю его в аргумент host_call оценщика.Однако вызов этого приводит к утечке памяти, и я не могу понять, в чем проблема.При использовании кучи кажется, что большие словари каким-то образом создаются и не выпускаются.
p_temp = tf.reshape(policy_loss, [1], name='policy_loss_reshape')
v_temp = tf.reshape(value_loss, [1], name='value_loss_reshape')
e_temp = tf.reshape(entropy_loss, [1], name='entropy_loss_reshape')
t_temp = tf.reshape(total_loss, [1], name='total_loss_reshape')
g_temp = tf.reshape(global_step, [1], name='global_step_reshape')
#
metric_args = [p_temp, v_temp, e_temp, t_temp, g_temp]
host_call_fn = functools.partial(
eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.TRAIN)
host_call = (host_call_fn, metric_args)
Следующая функция вычисляет дополнительные метрики оценки и записывает их в сводный каталог для Tensorboard.
def eval_metrics_host_call_fn(p_temp,
v_temp,
e_temp,
t_temp,
step,
est_mode=tf.estimator.ModeKeys.TRAIN):
#
with tf.variable_scope('metrics'):
metric_ops = {
'policy_loss': tf.metrics.mean(p_temp, name='policy_loss_metric'),
'value_loss': tf.metrics.mean(v_temp, name='value_loss_metric'),
'entropy_loss': tf.metrics.mean(e_temp, name='entropy_loss_metric'),
'total_loss': tf.metrics.mean(t_temp, name='total_loss_metric')
}
if est_mode == tf.estimator.ModeKeys.EVAL:
return metric_ops
eval_step = tf.reduce_min(step)
# Create summary ops so that they show up in SUMMARIES collection
# That way, they get logged automatically during training
summary_writer = summary.create_file_writer(FLAGS.summary_dir)
with summary_writer.as_default(
), summary.record_summaries_every_n_global_steps(FLAGS.summary_steps,
eval_step):
for metric_name, metric_op in metric_ops.items():
summary.scalar(metric_name, metric_op[1], step=eval_step)
# Reset metrics occasionally so that they are mean of recent batches.
reset_op = tf.variables_initializer(tf.local_variables('metrics'))
cond_reset_op = tf.cond(
tf.equal(eval_step % FLAGS.summary_steps, tf.to_int64(1)),
lambda: reset_op, lambda: tf.no_op())
return summary.all_summary_ops() + [cond_reset_op]