У меня есть классификатор, который я обучил, используя керас, который работает очень хорошо.Он использует keras.applications.MobileNetV2
.
Этот классификатор хорошо обучен примерно на 200 категориях и имеет высокую точность.
Однако я хотел бы использовать слои извлечения объектов из этого классификатора как частьмодели обнаружения объектов.
Я использовал API обнаружения объектов Tensorflow и изучал модель SSDLite + MobileNetV2.Я могу начать тренировку, но тренировка очень медленная, и большая часть потерь приходится на этап классификации.
Я хотел бы назначить веса из моей модели keras .h5
дляФункция Извлечение слоя MobileNetV2 в Tensorflow, но я не уверен, что лучший способ сделать это.
Я могу легко загрузить файл h5
и получить список имен слоев:
import keras
keras_model = keras.models.load_model("my_classifier.h5")
keras_names = [l.name for l in keras_model.layers]
print(keras_names)
Я также могу восстановить контрольную точку тензорного потока из API обнаружения объектов и экспортировать слои с весами:
tf.reset_default_graph()
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('models/model.ckpt.meta')
what = new_saver.restore(sess, 'models/model.ckpt')
tf_names = []
for op in sess.graph.get_operations():
if "MobilenetV2" in op.name and "Assign" in op.name:
tf_names.append(op.name)
print(tf_names)
Не получается получить хорошее соответствие между именами слоев из керасов и изtensorflow.Даже если бы я мог, я не уверен в следующих шагах.
Если бы кто-нибудь мог дать мне несколько советов о том, как лучше подойти к этому, я был бы очень благодарен.
Обновление:
Я последовал предложению Шарки, приведенному ниже, с небольшим изменением:
new_saver = tf.train.import_meta_graph(os.path.join(keras_checkpoint_dir, 'keras_model.ckpt.meta'))
new_saver.restore(sess, os.path.join(keras_checkpoint_dir, tf.train.latest_checkpoint(keras_checkpoint_dir)))
Однако, к сожалению, теперь я получаю эту ошибку:
NotFoundError (см. Выше для отслеживания): Восстановлениес контрольной точки не удалось.Скорее всего, это связано с отсутствием в контрольной точке имени переменной или другого ключа графика.Пожалуйста, убедитесь, что вы не изменили ожидаемый график на основе контрольной точки.Исходная ошибка:
Ключевой FeatureExtractor / MobilenetV2 / extended_conv_6 / projects / BatchNorm / гамма не найдена в контрольной точке [[узел сохранения / RestoreV2_295 (определенный в: 7) = RestoreV2 [dtypes = [DT_FLOAT], _device = "/задание: localhost / реплика: 0 / задача: 0 / устройство: ЦП: 0 "] (_arg_save / Const_0_0, save / RestoreV2_295 / тензор_имя, save / RestoreV2_295 / shape_and_slices)]] [[{{узел save / RestoreV2_196 / _393}}= _Recvclient_terminated = false, recv_device = "/ job: localhost / replica: 0 / task: 0 / device: GPU: 0", send_device = "/ job: localhost / replica: 0 / task: 0 / device: CPU: 0", send_device_incarnation = 1, тензор_имя = "edge_789_save / RestoreV2_196", тензор_тип = DT_FLOAT, _device = "/ job: localhost / реплика: 0 / task: 0 / устройство: GPU: 0"]]
Любые идеи о том, как избавиться от этой ошибки?