Импорт весов из классификатора Keras в API обнаружения объектов TF - PullRequest
0 голосов
/ 08 апреля 2019

У меня есть классификатор, который я обучил, используя керас, который работает очень хорошо.Он использует keras.applications.MobileNetV2.

Этот классификатор хорошо обучен примерно на 200 категориях и имеет высокую точность.

Однако я хотел бы использовать слои извлечения объектов из этого классификатора как частьмодели обнаружения объектов.

Я использовал API обнаружения объектов Tensorflow и изучал модель SSDLite + MobileNetV2.Я могу начать тренировку, но тренировка очень медленная, и большая часть потерь приходится на этап классификации.

Я хотел бы назначить веса из моей модели keras .h5 дляФункция Извлечение слоя MobileNetV2 в Tensorflow, но я не уверен, что лучший способ сделать это.

Я могу легко загрузить файл h5 и получить список имен слоев:

import keras

keras_model = keras.models.load_model("my_classifier.h5")

keras_names = [l.name for l in keras_model.layers]

print(keras_names)

Я также могу восстановить контрольную точку тензорного потока из API обнаружения объектов и экспортировать слои с весами:

tf.reset_default_graph()

with tf.Session() as sess:

    new_saver = tf.train.import_meta_graph('models/model.ckpt.meta')

    what = new_saver.restore(sess, 'models/model.ckpt')


    tf_names = []
    for op in sess.graph.get_operations():
        if "MobilenetV2" in op.name and "Assign" in op.name:
            tf_names.append(op.name)

    print(tf_names)

Не получается получить хорошее соответствие между именами слоев из керасов и изtensorflow.Даже если бы я мог, я не уверен в следующих шагах.

Если бы кто-нибудь мог дать мне несколько советов о том, как лучше подойти к этому, я был бы очень благодарен.

Обновление:

Я последовал предложению Шарки, приведенному ниже, с небольшим изменением:

new_saver = tf.train.import_meta_graph(os.path.join(keras_checkpoint_dir, 'keras_model.ckpt.meta'))

new_saver.restore(sess, os.path.join(keras_checkpoint_dir, tf.train.latest_checkpoint(keras_checkpoint_dir)))

Однако, к сожалению, теперь я получаю эту ошибку:

NotFoundError (см. Выше для отслеживания): Восстановлениес контрольной точки не удалось.Скорее всего, это связано с отсутствием в контрольной точке имени переменной или другого ключа графика.Пожалуйста, убедитесь, что вы не изменили ожидаемый график на основе контрольной точки.Исходная ошибка:

Ключевой FeatureExtractor / MobilenetV2 / extended_conv_6 / projects / BatchNorm / гамма не найдена в контрольной точке [[узел сохранения / RestoreV2_295 (определенный в: 7) = RestoreV2 [dtypes = [DT_FLOAT], _device = "/задание: localhost / реплика: 0 / задача: 0 / устройство: ЦП: 0 "] (_arg_save / Const_0_0, save / RestoreV2_295 / тензор_имя, save / RestoreV2_295 / shape_and_slices)]] [[{{узел save / RestoreV2_196 / _393}}= _Recvclient_terminated = false, recv_device = "/ job: localhost / replica: 0 / task: 0 / device: GPU: 0", send_device = "/ job: localhost / replica: 0 / task: 0 / device: CPU: 0", send_device_incarnation = 1, тензор_имя = "edge_789_save / RestoreV2_196", тензор_тип = DT_FLOAT, _device = "/ job: localhost / реплика: 0 / task: 0 / устройство: GPU: 0"]]

Любые идеи о том, как избавиться от этой ошибки?

1 Ответ

1 голос
/ 08 апреля 2019

Вы можете использовать tf.keras.estimator.model_to_estimator

estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=path)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, os.path.join(path/keras, tf.train.latest_checkpoint(path/keras)))
    print(tf.global_variables())

Это должно сделать работу. Обратите внимание, что он создаст подкаталог внутри первоначально указанного пути.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...