Да.Нужно быть осторожным, как сбрасывать переменные при обработке данных в пакетном режиме.Организация операций при расчете общих метрик (т. Е. Точности, точности или auc) и пакетных метрик различна.Необходимо сбросить текущие переменные до нуля, прежде чем рассчитывать значения точности каждой новой партии данных.
При tf.metrics.precision
создаются две рабочие переменные и помещаются в вычислительный график: true_positives
и false_positives
.Таким образом, вы можете выбрать переменные для сброса, используя scope
аргумент tf.get_collection()
.
import tensorflow as tf
import numpy as np
import numpy as np
import tensorflow as tf
labels = np.array([[1,1,1,0],
[1,1,1,0],
[1,1,1,0],
[1,1,1,0]], dtype=np.uint8)
predictions = np.array([[1,0,0,0],
[1,1,0,0],
[1,1,1,0],
[0,1,1,1]], dtype=np.uint8)
precision, update_op = tf.metrics.precision(labels, predictions, name = 'precision')
print(precision)
#Tensor("precision/value:0", shape=(), dtype=float32)
print(update_op)
#Tensor("precision/update_op:0", shape=(), dtype=float32)
tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)
#[<tf.Variable 'precision/true_positives/count:0' shape=() dtype=float32_ref>,
# <tf.Variable 'precision/false_positives/count:0' shape=() dtype=float32_ref>,
running_vars_precision = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='precision')
running_vars_auc_initializer = tf.variables_initializer(var_list=running_vars_precision )
with tf.Session() as sess:
sess.run(running_vars_auc_initializer)
print("tf precision/update_op: {}".format(sess.run([precision, update_op])))
#tf precision/update_op: [0.8888889, 0.8888889]
print("tf precision: {}".format(sess.run(precision)))
#tf precision: 0.8888888955116272