Как загрузить обученную модель, сохраненную с помощью export_inference_graph.py? - PullRequest
1 голос
/ 23 апреля 2020

Ниже приведен пример, в котором используется API обнаружения объектов tenorflow 1.15.0. В учебном пособии подробно рассматриваются следующие аспекты:

  • как загрузить модель
  • , как загрузить пользовательскую базу данных с файлами. xml, создать из них файлы .cvs и затем .record файлы
  • как настроить обучающий конвейер
  • как получить графики тензорной доски
  • как обучить net сохранению контрольных точек (используя model_main.py)
  • как экспортировать (сохранить) модель (используя export_inference_graph.py)

Что я не смог сделать sh, однако загружает сохраненную модель для ее использования. Я пробовал с tf.saved_model.loader.load(sess, flags, export_dir, но я получаю

INFO:tensorflow:Saver not created because there are no variables in the graph to restore.
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.

папка, указанная в export_dir, имеет следующую структуру:

+dir
   +saved_model
      -saved_model.pb
   -model.ckpt.data-00000-of-00001
   -model.ckpt.index
   -checkpoint
   -frozen_inference_graph.pb
   -model.ckpt.meta
   -pipeline.config

Моя конечная цель - захватить изображения с камеры и подайте их на net для обнаружения объектов в реальном времени. \ В качестве промежуточного шага, теперь я просто хочу иметь возможность подать одно изображение и получить вывод. Я смог тренировать net, но теперь не могу его использовать.

Заранее спасибо.

1 Ответ

1 голос
/ 23 апреля 2020

Я нашел пример того, как загрузить модель , которая пропустила мне go через нее. \ Поскольку формат папки файла, который загружен в этом примере, тот же, я получаю свой код Мне просто пришлось его адаптировать.

Порядковая функция, которая загружает модель:

def load_model(model_name):
  base_url = 'http://download.tensorflow.org/models/object_detection/'
  model_file = model_name + '.tar.gz'
  model_dir = tf.keras.utils.get_file(
    fname=model_name, 
    origin=base_url + model_file,
    untar=True)

  model_dir = pathlib.Path(model_dir)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model

Затем я использовал эту функцию для создания этой новой

def load_local_model(model_path):
  model_dir = pathlib.Path(model_path)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model

Сначала это не сработало, поскольку tf.saved_model.load ожидал 3 аргумента, но это было решено путем импорта двух блоков import в одном и том же примере. Я не знаю, какой импорт сработал и почему (я я отредактирую этот ответ, когда получу его), но на данный момент этот код работает, и пример позволяет делать больше вещей.

Блоки импорта:

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display

и

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

РЕДАКТИРОВАТЬ То, что действительно нужно для этого работало, был следующий блок.

import os
import pathlib


if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models

%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.

%%bash 
cd models/research
pip install .

В противном случае этот блок импорта не будет работать

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
...