Как разбить модель Keras на две отдельные части и сохранить их отдельно? - PullRequest
0 голосов
/ 07 мая 2020

Итак, у меня модель Keras на основе VGG 16. Архитектура ниже. Я хотел бы сохранить отдельно основу модели и голову (только последние несколько полностью связанных слоев).

Чтобы сохранить базу, я делаю следующее, что, кажется, работает:

base_model = load_model(model_name)
x = base_model.input
y = base_model.get_layer('second_hidden_dropout').output
base = Model(inputs = x, outputs = y)
base.save('base.h5')

Мне трудно сохранить голову отдельно, так как я получаю сообщение об ошибке, когда запускаю следующее:

    head = Model(inputs = y, outputs = base_model.outputs)
    head.save('head.h5')

ValueError: Graph disconnected: cannot obtain value for tensor Tensor("input_1_1:0", shape=(?, 256, 96, 3), dtype=float32) at layer "input_1". The following previous layers were accessed without issue: []

Причина. Я использую трансферное обучение, чтобы несколько раз повторно обучать последние несколько слоев для разных целей. Чтобы сэкономить время вычислений, я бы хотел отдельно сохранить выходные данные базы модели для моих изображений и запускать разные головы отдельно.

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 256, 96, 3)   0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 256, 96, 64)  1792        input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 256, 96, 64)  36928       block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_pool (MaxPooling2D)      (None, 128, 48, 64)  0           block1_conv2[0][0]               
__________________________________________________________________________________________________
block2_conv1 (Conv2D)           (None, 128, 48, 128) 73856       block1_pool[0][0]                
__________________________________________________________________________________________________
block2_conv2 (Conv2D)           (None, 128, 48, 128) 147584      block2_conv1[0][0]               
__________________________________________________________________________________________________
block2_pool (MaxPooling2D)      (None, 64, 24, 128)  0           block2_conv2[0][0]               
__________________________________________________________________________________________________
block3_conv1 (Conv2D)           (None, 64, 24, 256)  295168      block2_pool[0][0]                
__________________________________________________________________________________________________
block3_conv2 (Conv2D)           (None, 64, 24, 256)  590080      block3_conv1[0][0]               
__________________________________________________________________________________________________
block3_conv3 (Conv2D)           (None, 64, 24, 256)  590080      block3_conv2[0][0]               
__________________________________________________________________________________________________
block3_pool (MaxPooling2D)      (None, 32, 12, 256)  0           block3_conv3[0][0]               
__________________________________________________________________________________________________
block4_conv1 (Conv2D)           (None, 32, 12, 512)  1180160     block3_pool[0][0]                
__________________________________________________________________________________________________
block4_conv2 (Conv2D)           (None, 32, 12, 512)  2359808     block4_conv1[0][0]               
__________________________________________________________________________________________________
block4_conv3 (Conv2D)           (None, 32, 12, 512)  2359808     block4_conv2[0][0]               
__________________________________________________________________________________________________
block4_pool (MaxPooling2D)      (None, 16, 6, 512)   0           block4_conv3[0][0]               
__________________________________________________________________________________________________
block5_conv1 (Conv2D)           (None, 16, 6, 512)   2359808     block4_pool[0][0]                
__________________________________________________________________________________________________
block5_conv2 (Conv2D)           (None, 16, 6, 512)   2359808     block5_conv1[0][0]               
__________________________________________________________________________________________________
block5_conv3 (Conv2D)           (None, 16, 6, 512)   2359808     block5_conv2[0][0]               
__________________________________________________________________________________________________
block5_pool (MaxPooling2D)      (None, 8, 3, 512)    0           block5_conv3[0][0]               
__________________________________________________________________________________________________
pool (GlobalMaxPooling2D)       (None, 512)          0           block5_pool[0][0]                
__________________________________________________________________________________________________
first_hidden_layer (Dense)      (None, 512)          262656      pool[0][0]                       
__________________________________________________________________________________________________
first_hidden_dropout (Dropout)  (None, 512)          0           first_hidden_layer[0][0]         
__________________________________________________________________________________________________
second_hidden_layer (Dense)     (None, 256)          131328      first_hidden_dropout[0][0]       
==================================================================================================
Total params: 15,114,840
Trainable params: 15,114,840
Non-trainable params: 0
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...