Я пытаюсь загрузить контрольные точки и сохранить их средние веса, используя TF2.1. Я нашел версию TF1 для этого. https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py
Переменная «контрольные точки» представляет собой список путей контрольных точек
# Read variables from all checkpoints and average them.
logger.info("Reading variables and averaging checkpoints:")
for c in checkpoints:
logger.info(c)
var_list = tf.train.list_variables(checkpoints[0])
var_values, var_dtypes = {}, {}
for (name, shape) in var_list:
if not name.startswith("global_step"):
var_values[name] = tf.zeros(shape)
for checkpoint in checkpoints:
reader = tf.train.load_checkpoint(checkpoint)
for name in var_values:
tensor = tf.convert_to_tensor(reader.get_tensor(name))
if tensor.dtype == tf.string:
var_values[name] = tensor
else:
var_values[name] = tf.cast(var_values[name], tensor.dtype)
var_values[name] += tensor
var_dtypes[name] = tensor.dtype
logger.info("Read from checkpoint %s", checkpoint)
for name in var_values: # Average.
if var_dtypes[name] != tf.string:
var_values[name] /= len(checkpoints)
В холодном виде вы объясняете, как сохранить среднее значение var_values
в контрольной точке?