Вместо последовательного API вы можете попробовать функциональный API, предоставив форму ввода для 1D-CNN и 2D-CNN только на первом уровне. Затем, добавив выравнивание перед плотным слоем в 1D-CNN и 2D-CNN, вы решите свою проблему.
Вы можете следовать нижеприведенному измененному коду.
1D-CNN:
# ----------------------- 1D CNN ----------------------
in_1D = Input((7380, 128000))
# 1
model_1D = Conv1D(32, kernel_size= 5 , strides=1, activation='relu')(in_1D)
model_1D = MaxPooling1D(pool_size= 4, strides=4)(model_1D)
# 2
model_1D = Conv1D(32, kernel_size= 5 , strides=1 , activation='relu')(model_1D)
model_1D = MaxPooling1D(pool_size= 4, strides=4)(model_1D)
# 3
model_1D = Conv1D(64, kernel_size= 5 , strides=1 , activation='relu')(model_1D)
model_1D = MaxPooling1D(pool_size= 4, strides=4)(model_1D)
# 4
model_1D = Conv1D(64, kernel_size= 5 , strides=1 , activation='relu')(model_1D)
model_1D = MaxPooling1D(pool_size= 2, strides=2)(model_1D)
# 5
model_1D = Conv1D(128, kernel_size= 5 , strides= 1 , activation='relu')(model_1D)
model_1D = MaxPooling1D(pool_size= 2, strides= 2)(model_1D)
# 6
model_1D = Conv1D(128, kernel_size= 5 , strides= 1 , activation='relu')(model_1D)
model_1D = MaxPooling1D(pool_size= 2, strides= 2)(model_1D)
model_1D = Flatten()(model_1D)
model_1D = Dense(9 , activation='softmax')(model_1D)
2D-CNN:
# ----------------------- 2D CNN ----------------------
in_2D = Input((7380, 128, 251))
model_2D = Conv2D(32, kernel_size=(3, 3) , strides=(1,1), activation='relu')(in_2D)
model_2D = Conv2D(32, kernel_size=(3, 3) , strides=(1,1), activation='relu')(model_2D)
model_2D = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(model_2D)
model_2D = Conv2D(32, kernel_size=(3, 3) , strides=(1,1), activation='relu')(model_2D)
model_2D = Conv2D(32, kernel_size=(3, 3) , strides=(1,1), activation='relu')(model_2D)
model_2D = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(model_2D)
model_2D = Flatten()(model_2D)
model_2D = Dense(9 , activation='relu')(model_2D)
Слияние:
merged = Concatenate()([model_2D, model_1D])
output = Dense(7, activation='softmax')(merged)
model_combined = Model(inputs=[in_2D, in_1D], outputs=[output])
model_combined.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model_combined.summary()
Выход:
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 7380, 128000) 0
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 7376, 32) 20480032 input_1[0][0]
__________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D) (None, 1844, 32) 0 conv1d_1[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, 1840, 32) 5152 max_pooling1d_1[0][0]
__________________________________________________________________________________________________
max_pooling1d_2 (MaxPooling1D) (None, 460, 32) 0 conv1d_2[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 456, 64) 10304 max_pooling1d_2[0][0]
__________________________________________________________________________________________________
input_2 (InputLayer) (None, 7380, 128, 25 0
__________________________________________________________________________________________________
max_pooling1d_3 (MaxPooling1D) (None, 114, 64) 0 conv1d_3[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 7378, 126, 32 72320 input_2[0][0]
__________________________________________________________________________________________________
conv1d_4 (Conv1D) (None, 110, 64) 20544 max_pooling1d_3[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 7376, 124, 32 9248 conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling1d_4 (MaxPooling1D) (None, 55, 64) 0 conv1d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 3688, 62, 32) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv1d_5 (Conv1D) (None, 51, 128) 41088 max_pooling1d_4[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 3686, 60, 32) 9248 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
max_pooling1d_5 (MaxPooling1D) (None, 25, 128) 0 conv1d_5[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 3684, 58, 32) 9248 conv2d_3[0][0]
__________________________________________________________________________________________________
conv1d_6 (Conv1D) (None, 21, 128) 82048 max_pooling1d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 1842, 29, 32) 0 conv2d_4[0][0]
__________________________________________________________________________________________________
max_pooling1d_6 (MaxPooling1D) (None, 10, 128) 0 conv1d_6[0][0]
__________________________________________________________________________________________________
flatten_2 (Flatten) (None, 1709376) 0 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 1280) 0 max_pooling1d_6[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 9) 15384393 flatten_2[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 9) 11529 flatten_1[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 18) 0 dense_2[0][0]
dense_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 7) 133 concatenate_1[0][0]
==================================================================================================
Total params: 36,135,287
Trainable params: 36,135,287
Non-trainable params: 0