Как мне соединить две модели керас в одну модель? - PullRequest
1 голос
/ 26 апреля 2020

Допустим, у меня есть модель ResNet50, и я sh подключу выходной слой этой модели к входному слою модели VGG.

Это модель Re sNet и выходной тензор ResNet50:

img_shape = (164, 164, 3)
resnet50_model = ResNet50(include_top=False, input_shape=img_shape, weights = None)

print(resnet50_model.output.shape)

Я получаю вывод:

TensorShape([Dimension(None), Dimension(6), Dimension(6), Dimension(2048)])

Теперь я хочу новый слой где я изменяю этот выходной тензор в (64,64,18)

Затем у меня есть модель VGG16:

VGG_model = VGG_model = VGG16(include_top=False, weights=None)

Я хочу, чтобы выходные данные ResNet50 преобразились в желаемый тензор и подается как вход в модель VGG. По сути, я хочу объединить две модели. Может ли кто-нибудь помочь мне сделать это? Спасибо!

1 Ответ

1 голос
/ 27 апреля 2020

Есть несколько способов сделать это. Вот один из способов использования API последовательной модели:

import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

model = tf.keras.Sequential()
img_shape = (164, 164, 3)
model.add(ResNet50(include_top=False, input_shape=img_shape, weights = None))

model.add(tf.keras.layers.Reshape(target_shape=(64,64,18)))
model.add(tf.keras.layers.Conv2D(3,kernel_size=(3,3),name='Conv2d'))

VGG_model = VGG16(include_top=False, weights=None)
model.add(VGG_model)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.summary()

Сводка модели выглядит следующим образом

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
resnet50 (Model)             (None, 6, 6, 2048)        23587712  
_________________________________________________________________
reshape (Reshape)            (None, 64, 64, 18)        0         
_________________________________________________________________
Conv2d (Conv2D)              (None, 62, 62, 3)         489       
_________________________________________________________________
vgg16 (Model)                multiple                  14714688  
=================================================================
Total params: 38,302,889
Trainable params: 38,249,769
Non-trainable params: 53,120
_________________________________________________________________

Полный код здесь .

...