Расщепление модели Кераса на произвольный слой - PullRequest
0 голосов
/ 30 апреля 2019

Я пытаюсь создать функцию, которая разделяет модель Keras на указанный пользователем слой.У меня есть следующий код:

def return_split_models(model, layer):
    model_f, model_h = Sequential(), Sequential()
    for current_layer in range(0, layer+1):
        model_f.add(model.layers[current_layer])
    for current_layer in range(layer+1, len(model.layers)):
        model_h.add(model.layers[current_layer])
    return model_f, model_h

Однако, когда мы вернемся model_h и вызовем сводку, мы увидим ValueError, что модель никогда не вызывалась.Глядя на другие посты, кажется, что это связано со входными данными для model_h, однако я не могу найти примеров, обобщающих какой-либо определенный слой.У кого-нибудь есть указания?

1 Ответ

1 голос
/ 30 апреля 2019

Вам нужно добавить InputLayer к model_h.

from keras.layers import InputLayer

def return_split_models(model, layer):
    model_f, model_h = Sequential(), Sequential()
    for current_layer in range(0, layer+1):
        model_f.add(model.layers[current_layer])
    # add input layer
    model_h.add(InputLayer(input_shape=model.layers[layer+1].input_shape[1:]))
    for current_layer in range(layer+1, len(model.layers)):
        model_h.add(model.layers[current_layer])
    return model_f, model_h

Пример:

model = Sequential()
model.add(Dense(50,input_shape=(100,)))
model.add(Dense(40))
model.add(Dense(30))
model.add(Dense(20))
model.add(Dense(10))

model_f, model_h = return_split_models(model, 2)
print(model_f.summary())
print(model_h.summary())

# print
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 50)                5050      
_________________________________________________________________
dense_2 (Dense)              (None, 40)                2040      
_________________________________________________________________
dense_3 (Dense)              (None, 30)                1230      
=================================================================
Total params: 8,320
Trainable params: 8,320
Non-trainable params: 0
_________________________________________________________________
None
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 20)                620       
_________________________________________________________________
dense_5 (Dense)              (None, 10)                210       
=================================================================
Total params: 830
Trainable params: 830
Non-trainable params: 0
_________________________________________________________________
None
...