неверный скрипт tennors2tensor / avg_checkpoints - PullRequest
0 голосов
/ 05 ноября 2019

Я обучаю две подмодели параллельно, и хочу усреднить их после того, как они оба будут закончены.

Модель реализована из тензорного тензора (но все еще использует тензорный поток). Часть определения выглядит следующим образом:

    def build_model(self):
        layer_sizes = [n,n,n]
        kernel_sizes = [n,n,n]
        tf.reset_default_graph()
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.num_classes = num_classes
            # placeholder for parameter
            self.learning_rate = tf.placeholder(tf.float32, name="learning_rate")
            self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
            self.phase = tf.placeholder(tf.bool, name="phase")
            # Placeholders for regular data
            self.input_x = tf.placeholder(tf.float32, [None, None, input_feature_dim], name="input_x")
            self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
            h = self.input_x
            ......[remaining codes]

Я сохраняю его следующим образом:

    def save_model(sess, output):
        saver = tf.train.Saver()
        save_path = saver.save(sess, os.path.join(output, 'model'))

Когда я загружаю его, я использую:

    def load_model(self, sess, input_dir, logger):
        if logger is not None:
            logger.info("Start loading graph ...")
        saver = tf.train.import_meta_graph(os.path.join(input_dir, 'model.meta'))
        saver.restore(sess, os.path.join(input_dir, 'model'))
        self.graph = sess.graph
        self.input_x = self.graph.get_tensor_by_name("input_x:0")
        self.input_y = self.graph.get_tensor_by_name("input_y:0")
        self.num_classes = self.input_y.shape[1]
        self.learning_rate = self.graph.get_tensor_by_name("learning_rate:0")
        self.dropout_keep_prob = self.graph.get_tensor_by_name("dropout_keep_prob:0")
        self.phase = self.graph.get_tensor_by_name("phase:0")
        self.loss = self.graph.get_tensor_by_name("loss:0")
        self.optimizer = self.graph.get_operation_by_name("optimizer")
        self.accuracy = self.graph.get_tensor_by_name("accuracy/accuracy:0")

Я используюavg_checkpoint для усреднения двух подмоделей:

python utils/avg_checkpoints.py
  --checkpoints path/to/checkpoint1,path/to/checkpoint2
  --num_last_checkpoints 2
  --output_path where/to/save/the/output

Но я сталкиваюсь с проблемами при дальнейшей проверке кода avg_checkpoints:

    for checkpoint in checkpoints:
        reader = tf.train.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor
        tf.logging.info("Read from checkpoint %s", checkpoint)
    for name in var_values:  # Average.
        var_values[name] /= len(checkpoints)

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tf_vars = [
            tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
            for v in var_values
        ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
    global_step = tf.Variable(
        0, name="global_step", trainable=False, dtype=tf.int64)
    saver = tf.train.Saver(tf.all_variables())

    # Build a model consisting only of variables, set them to the average values.
    with tf.Session() as sess:
        for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                               six.iteritems(var_values)):
            sess.run(assign_op, {p: value})
        # Use the built saver to save the averaged checkpoint.
        saver.save(sess, FLAGS.output_path, global_step=global_step)

Сохраняются только переменные, а не все тензоры. Например, когда я загружаю его с помощью вышеуказанной функции load_model, он не может иметь тензор «input_x: 0». Этот скрипт неверный, или я должен изменить его в зависимости от моего использования?

Я использую TF r1.13. Спасибо!

...