Не удается загрузить контрольную точку tenorsflow keras при использовании tf.distribute.MirroredStrategy () - PullRequest
1 голос
/ 10 марта 2020

Я пытаюсь загрузить модель tf.keras (v1.15.0) из контрольной точки, созданной с помощью обратного вызова ModelCheckpoint, изменить ее, удалив несколько слоев и добавив новые, а затем продолжить обучение новой задаче. Я использую tf.distribute.MirroredStrategy () для распределенного обучения с 2 gpus.

strategy = tensorflow.distribute.MirroredStrategy()
with strategy.scope():

    # Load pretrained model from checkpoint
    model = get_model()
    model.load_weights('file_name.hdf5')

    # Chop off some layers, add new layers
    model = modify_pretrained_model(model)

    model.compile(optimizer=opt, loss=loss)

Модель отлично загружается и компилируется, и я могу запустить model.summary (), но когда я вызываю модель .fit () или model.predict () В моем стеке python появляются следующие ошибки:

  (0) Failed precondition: Error while reading resource variable compression0_conv0_batchnorm/moving_variance from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/compression0_conv0_batchnorm/moving_variance/N10tensorflow3VarE does not exist.
     [[{{node time_distributed_1/model_1/compression0_conv0_batchnorm/FusedBatchNormV3/ReadVariableOp_1}}]]
     [[dense_1_1/Sigmoid/_225]]
  (1) Failed precondition: Error while reading resource variable compression0_conv0_batchnorm/moving_variance from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/compression0_conv0_batchnorm/moving_variance/N10tensorflow3VarE does not exist.
     [[{{node time_distributed_1/model_1/compression0_conv0_batchnorm/FusedBatchNormV3/ReadVariableOp_1}}]]
0 successful operations.
1 derived errors ignored

Эта проблема , похоже, решает эту проблему точно, но без использования tf .распределение, чтобы продолжить обучение.

Когда я создаю экземпляр сеанса за пределами области распространения и устанавливаю ссылку на него внутри области распространения, код вылетает с той же ошибкой.

tf_config = some_custom_config
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()

strategy = tensorflow.distribute.MirroredStrategy()
with strategy.scope():

    set_session(sess)

    # Load pretrained model from checkpoint
    model = get_model()
    model.load_weights('file_name.hdf5')

    # Chop off some layers, add new layers
    model = modify_pretrained_model(model)

    model.compile(optimizer=opt, loss=loss)

1 Ответ

0 голосов
/ 11 марта 2020

Я провел хорошие 2-3 дня, пытаясь понять это. Единственное, что действительно сработало, это обновление до tf 2.0.0. Тогда все работало как волхвы c. В качестве последнего средства я смог обучить первую модель, добавить и удалить дополнительные слои, перекомпилировать и продолжить обучение в том же исполнении python с той же стратегией распространения, но так и не смог перезагрузить tf.keras ModelCheckpoint, используя стратегии распределения в тф 1.15.0.

...