Получить (входной) слой модели Keras после сохранения и перезагрузки его с диска - PullRequest
1 голос
/ 15 мая 2019

Я загрузил, расширил, обучил сеть VGG16 через Keras, затем сохранил ее на диск:

from keras.applications import VGG16
from keras import models

conv_base = VGG16(weights="imagenet", include_top=False)
model = models.Sequential()
model.add(conv_base)
...
model.compile(...)
model.fit(...)
model.save("saved_model.h5")

В другом сценарии я загружаю эту обученную модель снова:

from keras.models import load_model

model_vgg16 = load_model("saved_model.h5")
model_fails = model_vgg16.get_layer("vgg16")
model_fails.input

Эта последняя строка приводит к следующему исключению:

AttributeError: Layer vgg16 has multiple inbound nodes, hence the notion of "layer input" is ill-defined. Use `get_input_at(node_index)` instead.

Однако, когда я делаю то же самое для сети VGG16 напрямую, тогда она работает нормально:

from keras.applications import VGG16
from keras.models import load_model

model_works = VGG16(weights='imagenet', include_top=False)
model_works.input

Эта последняя строка непривести к ошибке.Итак, мой вопрос:
Как получить доступ к (входному) слою сохраненной, а затем повторно загруженной модели Keras?

Ответы [ 2 ]

1 голос
/ 15 мая 2019

После добавления модели VGG16 в вашу пользовательскую модель у нее будет два входных узла: один - это исходный входной узел, доступный с использованием conv_base.get_input_at(0), и другой входной узел, который создается для ввода из вашей пользовательской модели, который будетбыть доступным, используя conv_base.get_input_at(1) (это фактически ввод модели и эквивалентно model.input).Разница между узлом и слоем в Keras была подробно объяснена в этом ответе .

0 голосов
/ 15 мая 2019

Мой подход заключается в том, чтобы сначала напечатать имя всех слоев модели, а затем вызвать слой по его имени.

Например:

from keras.models import load_model
model_vgg16 = load_model("saved_model.h5")
mdoel_vgg16.summary()

Запишите имя нужного слоя и затем получите вход или выход слоя

layer_input = model_vgg16.get_layer('vgg16').get_layer(layer_name).input
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...