Мне интересно использовать трансферное обучение, как описано здесь: https://www.tensorflow.org/tutorials/images/transfer_learning
Проблема в том, что модель, которую я пытаюсь использовать в качестве базовой модели, не является таковой из известных встроенных моделей Keras, таких как MobileNetV2. Таким образом, я предполагаю, что мне нужно сделать следующий первый шаг (шаг 1), чтобы иметь возможность делать то, что упомянуто в учебном пособии для трансферного обучения (шаги 2-6).
1. Загрузите модель из каталога, содержащего файлы Saved_Model.
2. Заморозьте модель (сделайте неизменяемыми ее обучаемые параметры)
3. Создайте отдельный слой и поместите его поверх замороженной модели
4. Обучите полученную модель.
5. Сохраните только что обученную модель.
6. Сделайте предсказания, используя только что обученную модель.
Мой вопрос касается первого шага. Я получаю сообщение об ошибке, которое не понимаю, как это исправить при попытке загрузить модель, используя следующие Python коды / сценарии:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
tf.saved_model.load(
export_dir='/dir_to_the_model_files/', tags=None
)
Ошибка:
OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..
Я также думаю, что может быть способ конвертировать файлы TensorFlow, включая (save_model.ckpt-0.data-00000-of-00001), в файлы, которые можно прочитать с помощью API Keras (например, в формате h5py.File), которые могут облегчить передачу обучения, аналогичную упомянутому учебнику. Таким образом, я мог бы применить аналогичный метод к следующим, чтобы извлечь базовую модель и выполнить следующие шаги.
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')
Или предпочтительно использовать следующий метод из https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model:
tf.keras.models.load_model(
filepath, custom_objects=None, compile=True
)
Обновление: Я попробовал следующий метод, но он не работает (tf был импортирован с использованием совместимой версии import tensorflow.compat.v1. as tf
):
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/dir_to_the_model_files/saved_model.ckpt-0.meta')
saver.restore(sess, "/dir_to_the_model_files/saved_model.ckpt-0")
loaded = tf.saved_model.load(sess,tags=None,export_dir="/dir_to_the_model_files",import_scope=None)
Возвращает следующие предупреждения и ошибки:
WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'metric_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from /dir_to_the_model_files/saved_model.ckpt-0
<tensorflow.python.training.saver.Saver object at 0x2aaab4824a50>
WARNING:tensorflow:From <ipython-input-3-b8fd24f6b841>:9: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..