использование интерфейса оценщика для вывода с предварительно обученной моделью обнаружения тензорных объектов - PullRequest
1 голос
/ 01 мая 2019

Я пытаюсь загрузить предварительно обученную модель обнаружения тензорного объекта из репозитория Tensorflow Object Detection как tf.estimator.Estimator и использовать ее для прогнозирования.

Я могу загрузить модель и выполнить логический вывод, используя Estimator.predict(), однако вывод - мусор. Другие способы загрузки модели, например, как Predictor, так и работающий вывод работают нормально.

Любая помощь по правильной загрузке модели как Estimator, вызывающая predict(), была бы очень признательна. Мой текущий код:

Загрузить и подготовить изображение

def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(list(image.getdata())).reshape((im_height, im_width, 3)).astype(np.uint8)

image_url = 'https://i.imgur.com/rRHusZq.jpg'

# Load image
response = requests.get(image_url)
image = Image.open(BytesIO(response.content))

# Format original image size
im_size_orig = np.array(list(image.size) + [1])
im_size_orig = np.expand_dims(im_size_orig, axis=0)
im_size_orig = np.int32(im_size_orig)

# Resize image
image = image.resize((np.array(image.size) / 4).astype(int))

# Format image
image_np = load_image_into_numpy_array(image)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_np_expanded = np.float32(image_np_expanded)

# Stick into feature dict
x = {'image': image_np_expanded, 'true_image_shape': im_size_orig}

# Stick into input function
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x=x,
    y=None,
    shuffle=False,
    batch_size=128,
    queue_capacity=1000,
    num_epochs=1,
    num_threads=1,
)

Примечание:

train_and_eval_dict также, кажется, содержит input_fn для прогноза

train_and_eval_dict['predict_input_fn']

Однако на самом деле это возвращает tf.estimator.export.ServingInputReceiver, с чем я не уверен, что делать. Это может быть источником моих проблем, поскольку перед тем, как модель действительно увидит изображение, потребуется немало предварительной обработки.

Загрузить модель как Estimator

Модель загружена с TF Model Zoo здесь , код для загрузки модели адаптирован с здесь .

model_dir = './pretrained_models/tensorflow/ssd_mobilenet_v1_coco_2018_01_28/'
pipeline_config_path = os.path.join(model_dir, 'pipeline.config')

config = tf.estimator.RunConfig(model_dir=model_dir)

train_and_eval_dict = model_lib.create_estimator_and_inputs(
    run_config=config,
    hparams=model_hparams.create_hparams(None),
    pipeline_config_path=pipeline_config_path,
    train_steps=None,
    sample_1_of_n_eval_examples=1,
    sample_1_of_n_eval_on_train_examples=(5))

estimator = train_and_eval_dict['estimator']

Выполнить вывод

output_dict1 = estimator.predict(predict_input_fn)

Это распечатывает некоторые сообщения журнала, одно из которых:

INFO:tensorflow:Restoring parameters from ./pretrained_models/tensorflow/ssd_mobilenet_v1_coco_2018_01_28/model.ckpt

Похоже, предварительно загруженные грузы загружаются. Однако результаты выглядят так:

Image with bad detections

Загрузить ту же модель, что и Predictor

from tensorflow.contrib import predictor

model_dir = './pretrained_models/tensorflow/ssd_mobilenet_v1_coco_2018_01_28'
saved_model_dir = os.path.join(model_dir, 'saved_model')
predict_fn = predictor.from_saved_model(saved_model_dir)

Выполнить вывод

output_dict2 = predict_fn({'inputs': image_np_expanded})

Результаты выглядят хорошо:

enter image description here

1 Ответ

1 голос
/ 01 мая 2019

Когда вы загружаете модель как оценщик и из файла контрольных точек, вот функция восстановления, связанная с ssd моделями. От ssd_meta_arch.py

def restore_map(self,
                  fine_tune_checkpoint_type='detection',
                  load_all_detection_checkpoint_vars=False):
    """Returns a map of variables to load from a foreign checkpoint.
    See parent class for details.
    Args:
      fine_tune_checkpoint_type: whether to restore from a full detection
        checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.
        Valid values: `detection`, `classification`. Default 'detection'.
      load_all_detection_checkpoint_vars: whether to load all variables (when
         `fine_tune_checkpoint_type='detection'`). If False, only variables
         within the appropriate scopes are included. Default False.
    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.
    Raises:
      ValueError: if fine_tune_checkpoint_type is neither `classification`
        nor `detection`.
    """
    if fine_tune_checkpoint_type not in ['detection', 'classification']:
      raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
          fine_tune_checkpoint_type))

    if fine_tune_checkpoint_type == 'classification':
      return self._feature_extractor.restore_from_classification_checkpoint_fn(
          self._extract_features_scope)

    if fine_tune_checkpoint_type == 'detection':
      variables_to_restore = {}
      for variable in tf.global_variables():
        var_name = variable.op.name
        if load_all_detection_checkpoint_vars:
          variables_to_restore[var_name] = variable
        else:
          if var_name.startswith(self._extract_features_scope):
            variables_to_restore[var_name] = variable

    return variables_to_restore

Как вы можете видеть, даже если в файле конфигурации установлено значение from_detection_checkpoint: True, будут восстановлены только переменные в области выделения функций. Чтобы восстановить все переменные, вам нужно установить

load_all_detection_checkpoint_vars: True

в конфигурационном файле.

Итак, вышеуказанная ситуация вполне понятна. При загрузке модели в виде Estimator будут восстановлены только переменные из области действия средства извлечения объектов, а веса области действия предикторов не будут восстановлены, очевидно, что оценщик будет давать случайные прогнозы.

При загрузке модели в качестве предиктора загружаются все веса, поэтому прогнозы являются разумными.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...