Сбрасывать локальные переменные метрик после каждой эпохи - PullRequest
0 голосов
/ 02 января 2019

Я использую встроенный метод tf.metrics.precision для оценки моей модели. Я просматривал определение , но локальные переменные никогда не сбрасывались.

Не следует ли их сбрасывать после каждой эпохи, чтобы убрать отсчеты с последних эпох? Это сделано автоматически, и я просто пропустил это в исходном коде, или я должен это сделать? Если последнее верно, как мне сбросить локальные переменные? Я ничего не читал об этом в документации.

Ответы [ 2 ]

0 голосов
/ 02 января 2019

Да.Нужно быть осторожным, как сбрасывать переменные при обработке данных в пакетном режиме.Организация операций при расчете общих метрик (т. Е. Точности, точности или auc) и пакетных метрик различна.Необходимо сбросить текущие переменные до нуля, прежде чем рассчитывать значения точности каждой новой партии данных.

При tf.metrics.precision создаются две рабочие переменные и помещаются в вычислительный график: true_positives и false_positives.Таким образом, вы можете выбрать переменные для сброса, используя scope аргумент tf.get_collection().

import tensorflow as tf
import numpy as np

import numpy as np
import tensorflow as tf

labels = np.array([[1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0],
                   [1,1,1,0]], dtype=np.uint8)

predictions = np.array([[1,0,0,0],
                        [1,1,0,0],
                        [1,1,1,0],
                        [0,1,1,1]], dtype=np.uint8)

precision, update_op = tf.metrics.precision(labels, predictions, name = 'precision')

print(precision)
#Tensor("precision/value:0", shape=(), dtype=float32)
print(update_op)
#Tensor("precision/update_op:0", shape=(), dtype=float32)

tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)
#[<tf.Variable 'precision/true_positives/count:0' shape=() dtype=float32_ref>,
# <tf.Variable 'precision/false_positives/count:0' shape=() dtype=float32_ref>,

running_vars_precision = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='precision')
running_vars_auc_initializer = tf.variables_initializer(var_list=running_vars_precision )

with tf.Session() as sess:
    sess.run(running_vars_auc_initializer)
    print("tf precision/update_op: {}".format(sess.run([precision, update_op])))
    #tf precision/update_op: [0.8888889, 0.8888889]
    print("tf precision: {}".format(sess.run(precision)))
    #tf precision: 0.8888888955116272
0 голосов
/ 02 января 2019

Переменные для отслеживания метрик создаются с помощью функции metric_variable и, таким образом, добавляются в коллекцию с помощью ключа tf.GraphKeys.METRIC_VARIABLES. После того как вы определили все свои метрики, вы можете выполнить операцию сброса следующим образом:

reset_metrics_op = tf.variables_initializer(tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))

И запустить его после окончания каждой эпохи.

...