Как оказалось, TensorFlow выполняет три разных способа проверки, в зависимости от того, что проверяется.
Объект контрольной точки - это просто переменная. Это восстанавливается сразу после вызова checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))
.
Объект контрольной точки - это модель с заданной формой ввода. Это также восстанавливается немедленно.
Объект с контрольной точкой - это модель без заданной входной формы. Именно здесь поведение меняется, так как TensorFlow выполняет «отложенное» восстановление и НЕ будет восстанавливать вес модели, пока входные данные не будут переданы в модель.
Вот пример:
import os
import tensorflow as tf
import numpy as np
# Disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()
# Create model
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(256, 3, padding="same"),
tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])
# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)
# Make fake data
img = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
# Train
with tf.GradientTape() as tape:
logits = model(img)
loss = tf.losses.mean_squared_error(truth, logits)
# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)
# Save model
checkpoint_path = './ckpt/'
checkpoint.save('./ckpt/')
# Check if weights update
print("Are weights empty after training?", model.weights == [])
# Reset model
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(256, 3, padding="same"),
tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])
# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))
# This next line is REQUIRED to restore
#model(img)
print("Are weights empty after restoring from checkpoint?", model.weights == [])
print(status)
status.assert_existing_objects_matched()
status.assert_consumed()
С выходом:
Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f6256b4ddd8>
Traceback (most recent call last):
File "test.py", line 58, in <module>
status.assert_consumed()
File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed
raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
name: "VARIABLE_VALUE"
full_name: "sequential/conv2d/kernel"
checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
}
Однако, раскомментируя строку model(img)
, вы получите следующий вывод:
Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? False
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7ff62320fe48>
Таким образом, входные данные должны быть переданы для правильного восстановления модели, инвариантной формы.
Ссылки:
https://www.tensorflow.org/alpha/guide/checkpoints#delayed_restorations
https://github.com/tensorflow/tensorflow/issues/27937