Загрузка в нетерпеливом TensorFlow сломана прямо сейчас? - PullRequest
1 голос
/ 17 апреля 2019

Веса в классах, унаследованных от tf.keras.Model, похоже, не могут загрузить в данный момент.Я не могу загрузить веса из Example () вне класса, используя контрольные точки, поэтому я попытался сделать это внутри, что по всем учетным записям должно работать.Он может сохранять веса, как при сохранении Example (), но все равно не может их загрузить.Это код моей модели:

class Example(tf.keras.Model):
    def __init__(self, cfg):
        super(Example, self).__init__()

        self.model = tf.keras.Sequential([
             ........layers.......
        ])

        # Create saver
        self.save_path = cfg.save_dir + cfg.extension
        self.ckpt_prefix = self.save_path + '/ckpt'
        self.saver = tf.train.Checkpoint(model=self.model)

    def call(self, x_in):
        x_out = self.model(x_in)
        return x_out

    def save(self):
        self.saver.save(file_prefix=self.ckpt_prefix)

    def load(self):
        self.saver.restore(tf.train.latest_checkpoint(self.save_path))

И это то, что я использую, чтобы проверить, загружается ли он:

example = Example()
if Path(self.example.save_path).is_dir():
            print(self.example.weights)
            print(self.example.model.weights)
            self.example.load()
            print(self.example.weights)
            print(self.example.model.weights)

Вывод:

[]
[]
[]
[]

Это было проверенона обоих тензорных потоках 1.3 и 2.0, и я могу подтвердить, что веса не пустые после первой партии, а также что это контрольная точка / сохранение.

1 Ответ

1 голос
/ 18 апреля 2019

Как оказалось, TensorFlow выполняет три разных способа проверки, в зависимости от того, что проверяется.

  1. Объект контрольной точки - это просто переменная. Это восстанавливается сразу после вызова checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path)).

  2. Объект контрольной точки - это модель с заданной формой ввода. Это также восстанавливается немедленно.

  3. Объект с контрольной точкой - это модель без заданной входной формы. Именно здесь поведение меняется, так как TensorFlow выполняет «отложенное» восстановление и НЕ будет восстанавливать вес модели, пока входные данные не будут переданы в модель.

Вот пример:

import os
import tensorflow as tf
import numpy as np

# Disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()

# Create model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])

# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)

# Make fake data
img = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
# Train
with tf.GradientTape() as tape:
    logits = model(img)
    loss = tf.losses.mean_squared_error(truth, logits)

# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)

# Save model
checkpoint_path = './ckpt/'
checkpoint.save('./ckpt/')

# Check if weights update
print("Are weights empty after training?", model.weights == [])

# Reset model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])

# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))

# This next line is REQUIRED to restore
#model(img)

print("Are weights empty after restoring from checkpoint?", model.weights == [])
print(status)
status.assert_existing_objects_matched()
status.assert_consumed()

С выходом:

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f6256b4ddd8>
Traceback (most recent call last):
  File "test.py", line 58, in <module>
    status.assert_consumed()
  File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed
    raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
  name: "VARIABLE_VALUE"
  full_name: "sequential/conv2d/kernel"
  checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
}

Однако, раскомментируя строку model(img), вы получите следующий вывод:

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? False
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7ff62320fe48>

Таким образом, входные данные должны быть переданы для правильного восстановления модели, инвариантной формы.

Ссылки:

https://www.tensorflow.org/alpha/guide/checkpoints#delayed_restorations https://github.com/tensorflow/tensorflow/issues/27937

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...