Можно ли преобразовать обученную модель в TensorFlow в объект, который можно использовать для трансферного обучения? - PullRequest
0 голосов
/ 04 марта 2020

Мне интересно использовать трансферное обучение, как описано здесь: https://www.tensorflow.org/tutorials/images/transfer_learning

Проблема в том, что модель, которую я пытаюсь использовать в качестве базовой модели, не является таковой из известных встроенных моделей Keras, таких как MobileNetV2. Таким образом, я предполагаю, что мне нужно сделать следующий первый шаг (шаг 1), чтобы иметь возможность делать то, что упомянуто в учебном пособии для трансферного обучения (шаги 2-6).
1. Загрузите модель из каталога, содержащего файлы Saved_Model.
2. Заморозьте модель (сделайте неизменяемыми ее обучаемые параметры)
3. Создайте отдельный слой и поместите его поверх замороженной модели
4. Обучите полученную модель.
5. Сохраните только что обученную модель.
6. Сделайте предсказания, используя только что обученную модель.

Мой вопрос касается первого шага. Я получаю сообщение об ошибке, которое не понимаю, как это исправить при попытке загрузить модель, используя следующие Python коды / сценарии:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
tf.saved_model.load(
    export_dir='/dir_to_the_model_files/', tags=None
)

Ошибка:

OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..

Я также думаю, что может быть способ конвертировать файлы TensorFlow, включая (save_model.ckpt-0.data-00000-of-00001), в файлы, которые можно прочитать с помощью API Keras (например, в формате h5py.File), которые могут облегчить передачу обучения, аналогичную упомянутому учебнику. Таким образом, я мог бы применить аналогичный метод к следующим, чтобы извлечь базовую модель и выполнить следующие шаги.

base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

Или предпочтительно использовать следующий метод из https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model:

tf.keras.models.load_model(
    filepath, custom_objects=None, compile=True
)

Обновление: Я попробовал следующий метод, но он не работает (tf был импортирован с использованием совместимой версии import tensorflow.compat.v1. as tf):

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/dir_to_the_model_files/saved_model.ckpt-0.meta')
    saver.restore(sess, "/dir_to_the_model_files/saved_model.ckpt-0")
    loaded = tf.saved_model.load(sess,tags=None,export_dir="/dir_to_the_model_files",import_scope=None)

Возвращает следующие предупреждения и ошибки:

WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'metric_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from /dir_to_the_model_files/saved_model.ckpt-0
<tensorflow.python.training.saver.Saver object at 0x2aaab4824a50>
WARNING:tensorflow:From <ipython-input-3-b8fd24f6b841>:9: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.

OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..

1 Ответ

0 голосов
/ 05 марта 2020

Может помочь документация TensorFlow для tf.saved_model.load:

SavedModels из API tf.estimator.Estimator или 1.x SavedModel имеют плоский график вместо tf.function объекты. Эти SavedModels будут иметь функции, соответствующие их сигнатурам в атрибуте .signatures, но также имеют метод .prune, который позволяет извлекать функции для новых подграфов. Это эквивалентно импорту SavedModel и именования каналов и выборок в сеансе из TensorFlow 1.x.

Возможно, вам придется использовать устаревший вызов API v1 https://www.tensorflow.org/api_docs/python/tf/compat/v1/saved_model/load

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...