Как сохранить средний вес контрольных точек с помощью Tensorflow 2.1? - PullRequest
0 голосов
/ 23 марта 2020

Я пытаюсь загрузить контрольные точки и сохранить их средние веса, используя TF2.1. Я нашел версию TF1 для этого. https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py

Переменная «контрольные точки» представляет собой список путей контрольных точек

  # Read variables from all checkpoints and average them.
  logger.info("Reading variables and averaging checkpoints:")
  for c in checkpoints:
    logger.info(c)

  var_list = tf.train.list_variables(checkpoints[0])

  var_values, var_dtypes = {}, {}
  for (name, shape) in var_list:
    if not name.startswith("global_step"):
      var_values[name] = tf.zeros(shape)

  for checkpoint in checkpoints:
    reader = tf.train.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = tf.convert_to_tensor(reader.get_tensor(name))

      if tensor.dtype == tf.string:
        var_values[name] = tensor
      else:
        var_values[name] = tf.cast(var_values[name], tensor.dtype)
        var_values[name] += tensor
      var_dtypes[name] = tensor.dtype
    logger.info("Read from checkpoint %s", checkpoint)

  for name in var_values:  # Average.
    if var_dtypes[name] != tf.string:
      var_values[name] /= len(checkpoints)

В холодном виде вы объясняете, как сохранить среднее значение var_values в контрольной точке?

1 Ответ

0 голосов
/ 27 марта 2020

Я мог бы сохранить среднюю контрольную точку, ссылаясь на версию Keras того же вопроса, поскольку Tensorflow 2.1 следует API Keras.

URL: Средние веса в моделях keras

...