Функция tf.train.get_checkpoint_state
(строка 246) имеет следующую сигнатуру функции (как определено в файле checkpoint_management )
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
"""Returns CheckpointState proto from the "checkpoint" file.
If the "checkpoint" file contains a valid CheckpointState
proto, returns it.
Args:
checkpoint_dir: The directory of checkpoints.
latest_filename: Optional name of the checkpoint file. Default to
'checkpoint'.
Returns:
A CheckpointState if the state was available, None
otherwise.
Raises:
ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
"""
Для меня не очень ясно, что именно latest_filename
ожидает, или как я могу получить конкретную контрольную точку, а не последнюю (если несколько контрольных точек находятся в одном каталоге)
Вторая строка вышеуказанной функции:
coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
latest_filename)
, который определен в строке 43 как:
def _GetCheckpointFilename(save_dir, latest_filename):
"""Returns a filename for storing the CheckpointState.
Args:
save_dir: The directory for saving and restoring checkpoints.
latest_filename: Name of the file in 'save_dir' that is used
to store the CheckpointState.
Returns:
The path of the file that contains the CheckpointState proto.
"""
if latest_filename is None:
latest_filename = "checkpoint"
return os.path.join(save_dir, latest_filename)
Так что это сужает тип до строки.
API tf.estimator.Estimator
сохраняет контрольные точки в виде:
model.ckpt-<#####>.index
model.ckpt-<#####>.meta
model.ckpt-<#####>.data-<#####>-of-<#####>
Так что, если я использую эти контрольные точки, если я хочу указать , какую контрольную точку, я могу позвонить:
tf.train.get_checkpoint_state(estimator_model_dir, latest_filename="model.ckpt-<#####>")
Я пробовал:
CHECKPOINT_DIR = "path/to/checkpoints"
ckpt_num = "model.ckpt-#####"
file = ckpt_num
# file = ckpt_num + 'data-00000-of-00001'
# file = ckpt_num + 'index'
# file = ckpt_num + 'meta'
checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR, file)
checkpoint.model_checkpoint_path
все из которых выбрасывают ошибки, например
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa1 in position 0: invalid start byte
Оставляя файл, я получаю последнюю контрольную точку, которая может быть не той, которую я хочу ...