Сверточно-классификационная модель в основной модели - PullRequest
0 голосов
/ 04 марта 2019

Мне нужно создать модель нейронной сети, например:

convolution --> classification
       \            /
        \          /
        _\|      |/_
         third model
       with one output

Свертка выводит данные, которые используются в качестве входных данных для модели классификации.После этого результаты свертки и классификации заполняются (объединяются) для третьей модели.Третья модель выдаст прогноз 0..1, который используется для обучения всей сети.

  • Прежде всего: Возможно ли правильно распространить модель классификации обратно, в этой ситуации? Или это требует создания трех отдельных моделей?
  • Я пытался объединить свертку и классификацию, но без хороших результатов.Я получил сообщение об ошибке «Отключен график».

Полный журнал ошибок: «Отключен график: невозможно получить значение для тензорного тензора («ification_prediction_Input_2: 0», shape = (1, 512), dtype = float32) на слое "ification_prediction_Input ". Следующие предыдущие слои были доступны без проблем: []".

Если идея верна, как соединить модели, как на "графике"?

Мой код всейчас:

# state convolution
state_input = Input(shape=INPUT_SHAPE, name='state_input', batch_shape=(1, 210, 160, 3))
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)
state_convolution_model = Model(state_input, state_outputs, name='state_convolution_model')
state_convolution_model.compile(optimizer='adam', loss='mean_squared_error', metrics=['acc'])

state_convolution_model_input = Input(shape=INPUT_SHAPE, name='state_convolution_model_input', batch_shape=(1, 210, 160, 3))
state_convolution = state_convolution_model(state_convolution_model_input)

# classification output
classficication_Input = Input(shape=(1, LSTM_OUTPUT_DIM), batch_shape=(1, LSTM_OUTPUT_DIM), name='classification_prediction_Input')
classficication_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(classficication_Input)
classficication_output_raw = Dense(ACTIONS, activation='sigmoid', name='classification_output_raw')(classficication_Dense_1)
classficication_output = Reshape((ACTIONS,), name='classification_output')(classficication_output_raw)
classficication_model = Model(classficication_Input, classficication_output, name='classificationPrediction_model')
classficication_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

classficicationPrediction = classficication_model(state_convolution)

i = keras.layers.concatenate([state_outputs, classficication_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)                  # <-- graph error is here
plot_model(model, to_file='model.png', show_shapes=True)

1 Ответ

0 голосов
/ 04 марта 2019

Да, вы можете построить такую ​​структуру и обучить ее сквозным способом.Однако вам нужно создать одну модель, имеющую несколько ветвей.Другая проблема, которую я вижу, состоит в том, что вы компилируете модель до того, как она будет полностью определена.Вот рабочий код:

# state convolution                                                                                                                                                                                                                                                   
state_input = Input(shape=INPUT_SHAPE, name='state_input')
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)

# classification output                                                                                                                                                                                                                                               
classification_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(state_outputs)
classification_output_raw = Dense(ACTIONS,                                                                                                                                                                                                                            
                                  activation='sigmoid',                                                                                                                                                                                                               
                                  name='classification_output_raw')(classification_Dense_1)
classification_output = Reshape((ACTIONS,), name='classification_output')(classification_output_raw)


i = concatenate([state_outputs, classification_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)                  # <-- no graph error anymore here                                                                                                                                                                                      
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
model.summary()

Вывод:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
state_input (InputLayer)        (None, 210, 160, 3)  0                                            
__________________________________________________________________________________________________
state_Conv2D_1 (Conv2D)         (None, 51, 39, 8)    1544        state_input[0][0]                
__________________________________________________________________________________________________
state_MaxPooling2D_1 (MaxPoolin (None, 25, 19, 8)    0           state_Conv2D_1[0][0]             
__________________________________________________________________________________________________
state_Flatten (Flatten)         (None, 3800)         0           state_MaxPooling2D_1[0][0]       
__________________________________________________________________________________________________
classification_prediction_Dense (None, 32)           121632      state_Flatten[0][0]              
__________________________________________________________________________________________________
classification_output_raw (Dens (None, 4)            132         classification_prediction_Dense_1
__________________________________________________________________________________________________
classification_output (Reshape) (None, 4)            0           classification_output_raw[0][0]  
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 3804)         0           state_Flatten[0][0]              
                                                                 classification_output[0][0]      
__________________________________________________________________________________________________
dense (Dense)                   (None, 32)           121760      concatenate[0][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            33          dense[0][0]                      
==================================================================================================

См. Руководство по функциональному API для дополнительных примеров.

...