TensorFlow v1.10 +: get_checkpoint_state, каково ожидаемое значение для last_filename для API Estimator, чтобы указать конкретную контрольную точку - PullRequest
1 голос
/ 27 июня 2019

Функция tf.train.get_checkpoint_state (строка 246) имеет следующую сигнатуру функции (как определено в файле checkpoint_management )

def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.
  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.
  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.
  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """

Для меня не очень ясно, что именно latest_filename ожидает, или как я могу получить конкретную контрольную точку, а не последнюю (если несколько контрольных точек находятся в одном каталоге)

Вторая строка вышеуказанной функции:

coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)

, который определен в строке 43 как:

def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.
  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.
  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename)

Так что это сужает тип до строки.

API tf.estimator.Estimator сохраняет контрольные точки в виде:

model.ckpt-<#####>.index
model.ckpt-<#####>.meta
model.ckpt-<#####>.data-<#####>-of-<#####>

Так что, если я использую эти контрольные точки, если я хочу указать , какую контрольную точку, я могу позвонить:

tf.train.get_checkpoint_state(estimator_model_dir, latest_filename="model.ckpt-<#####>")

Я пробовал:

CHECKPOINT_DIR = "path/to/checkpoints"
ckpt_num = "model.ckpt-#####"

file = ckpt_num
# file = ckpt_num + 'data-00000-of-00001'
# file = ckpt_num + 'index'
# file = ckpt_num + 'meta'

checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR, file)
checkpoint.model_checkpoint_path

все из которых выбрасывают ошибки, например

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa1 in position 0: invalid start byte

Оставляя файл, я получаю последнюю контрольную точку, которая может быть не той, которую я хочу ...

...