Можно ли загрузить и обучить модель, если у нас есть только файлы контрольных точек? - PullRequest
0 голосов
/ 27 марта 2020

Можно ли загрузить и обучить модель из файлов контрольных точек? У нас есть информация о входном и выходном тензорной форме.

Файлы контрольных точек

Ответы [ 2 ]

0 голосов
/ 27 марта 2020

Да, это возможно, если контрольная точка содержит параметры модели (параметры как W и b в W * x + b). Я думаю, что у вас есть, что в случае передачи обучения , вы можете использовать это на основе ваших файлов.

# Loads the weights
model.load_weights(checkpoint_path)

Вы должны знать архитектуру модели и создать модель перед использованием это. В некоторых моделях есть определенный c способ загрузки контрольной точки.

Также, проверьте это: https://www.tensorflow.org/tutorials/keras/save_and_load

0 голосов
/ 27 марта 2020

Да. Вы можете использовать tenorflow-keras, следуя этому примеру.

https://www.tensorflow.org/guide/checkpoint

Непосредственно из документации по тензорному потоку.

Список контрольных точек

!ls ./tf_ckpts

, который производит

checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Восстановление с контрольной точки

Вызов restore() для tf.train.Checkpoint объекта ставит в очередь запрошенные восстановления, восстанавливая значения переменных, как только будет найден соответствующий путь от объекта Checkpoint. Например, мы можем загрузить только смещение из модели, которую мы определили выше, восстановив один путь к нему через сеть и слой.

to_restore = tf.Variable(tf.zeros([5])) # variables from your model. 
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net) 
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/')) 
print(to_restore.numpy())  # We get the restored value now

Чтобы дважды проверить, что оно было восстановлено, вы можете набрать:

status.assert_existing_objects_matched()

и получите следующий вывод.

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1d796da278>
...