У меня есть модель на основе MobileNet для задачи регрессии:
def MobileNet_v1():
    """Build and compile a MobileNet landmark regressor with a Flatten+Dense head.

    Architecture: ImageNet-pretrained MobileNet backbone (alpha=1.0,
    depth_multiplier=1, no classification top) -> Flatten -> Dropout(0.5)
    -> Dense(config.N_LANDMARKS * 2, linear). The model is compiled with
    Adadelta and ``mae_loss``, and its summary is printed.

    Returns:
        The compiled Keras ``Model``.
    """
    # Keras 2.1.6
    mobilenet = MobileNet(
        input_shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS),
        alpha=1.0,
        depth_multiplier=1,
        include_top=False,
        weights='imagenet',
    )
    # With the 128x128x3 input shown in the printed summary, the backbone emits
    # (4, 4, 1024), so Flatten yields 16384 features and the Dense layer below
    # holds 16384 * 156 + 156 weights -- the single largest layer in the model.
    x = Flatten()(mobilenet.output)
    x = Dropout(0.5)(x)
    # N_LANDMARKS * 2 linear outputs -- presumably an (x, y) pair per landmark.
    x = Dense(config.N_LANDMARKS * 2, activation='linear')(x)
    # -------------------------------------------------------
    model = Model(inputs=mobilenet.input, outputs=x)
    optimizer = Adadelta()
    model.compile(optimizer=optimizer, loss=mae_loss)
    model.summary()
    return model
Структура сети:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 128, 128, 3) 0
_________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 130, 130, 3) 0
_________________________________________________________________
conv1 (Conv2D) (None, 64, 64, 32) 864
_________________________________________________________________
conv1_bn (BatchNormalization (None, 64, 64, 32) 128
_________________________________________________________________
conv1_relu (Activation) (None, 64, 64, 32) 0
_________________________________________________________________
conv_pad_1 (ZeroPadding2D) (None, 66, 66, 32) 0
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D) (None, 64, 64, 32) 288
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 64, 64, 32) 128
_________________________________________________________________
conv_dw_1_relu (Activation) (None, 64, 64, 32) 0
_________________________________________________________________
conv_pw_1 (Conv2D) (None, 64, 64, 64) 2048
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 64, 64, 64) 256
_________________________________________________________________
conv_pw_1_relu (Activation) (None, 64, 64, 64) 0
_________________________________________________________________
conv_pad_2 (ZeroPadding2D) (None, 66, 66, 64) 0
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D) (None, 32, 32, 64) 576
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 32, 32, 64) 256
_________________________________________________________________
conv_dw_2_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv_pw_2 (Conv2D) (None, 32, 32, 128) 8192
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_pw_2_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pad_3 (ZeroPadding2D) (None, 34, 34, 128) 0
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D) (None, 32, 32, 128) 1152
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_dw_3_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pw_3 (Conv2D) (None, 32, 32, 128) 16384
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_pw_3_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pad_4 (ZeroPadding2D) (None, 34, 34, 128) 0
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D) (None, 16, 16, 128) 1152
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 16, 16, 128) 512
_________________________________________________________________
conv_dw_4_relu (Activation) (None, 16, 16, 128) 0
_________________________________________________________________
conv_pw_4 (Conv2D) (None, 16, 16, 256) 32768
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_pw_4_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pad_5 (ZeroPadding2D) (None, 18, 18, 256) 0
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D) (None, 16, 16, 256) 2304
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_dw_5_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pw_5 (Conv2D) (None, 16, 16, 256) 65536
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_pw_5_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pad_6 (ZeroPadding2D) (None, 18, 18, 256) 0
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D) (None, 8, 8, 256) 2304
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 8, 8, 256) 1024
_________________________________________________________________
conv_dw_6_relu (Activation) (None, 8, 8, 256) 0
_________________________________________________________________
conv_pw_6 (Conv2D) (None, 8, 8, 512) 131072
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_6_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_7 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_7_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_7 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_7_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_8 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_8_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_8 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_8_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_9 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_9_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_9 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_9_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_10 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_10_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_10 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_10_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_11 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_11_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_11 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_11_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_12 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 4, 4, 512) 4608
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 4, 4, 512) 2048
_________________________________________________________________
conv_dw_12_relu (Activation) (None, 4, 4, 512) 0
_________________________________________________________________
conv_pw_12 (Conv2D) (None, 4, 4, 1024) 524288
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_pw_12_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv_pad_13 (ZeroPadding2D) (None, 6, 6, 1024) 0
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 4, 4, 1024) 9216
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_dw_13_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv_pw_13 (Conv2D) (None, 4, 4, 1024) 1048576
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_pw_13_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 16384) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 16384) 0
_________________________________________________________________
dense_1 (Dense) (None, 156) 2556060
=================================================================
Total params: 5,784,924
Trainable params: 5,763,036
Non-trainable params: 21,888
_________________________________________________________________
Как видно, около половины параметров сети (2 556 060 из 5 784 924, т.е. ~44%) приходится на последний полносвязный слой. Отсюда вопрос: как уменьшить размер модели, если сеть уже обучена? Я пробовал заменить полносвязный слой глобальным усредняющим пулингом (global average pooling), но для моей задачи регрессии он работает плохо, так что этот вариант не подходит. Поэтому я ищу что-то вроде уменьшения размера полносвязного слоя или его разреживания (sparsification).
Обновление:
Пример сети с глобальным усредняющим пулингом (GAP):
def MobileNet_v2():
    """Build and compile a MobileNet landmark regressor with a 1x1-conv + GAP head.

    Architecture: ImageNet-pretrained MobileNet backbone (alpha=1.0,
    depth_multiplier=1, no classification top) -> Conv2D with
    config.N_LANDMARKS * 2 filters of size 1x1 (linear) ->
    GlobalAveragePooling2D. The model is compiled with Adadelta and
    ``mae_loss``, and its summary is printed.

    Returns:
        The compiled Keras ``Model``.
    """
    # MobileNet with GAP layer on top
    # Keras 2.1.6
    mobilenet = MobileNet(
        input_shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS),
        alpha=1.0,
        depth_multiplier=1,
        include_top=False,
        weights='imagenet',
    )
    # The 1x1 conv maps every spatial cell to N_LANDMARKS * 2 values; GAP then
    # averages them over the spatial grid. Its weight count does not depend on
    # the spatial size, unlike a Flatten + Dense head.
    x = Conv2D(filters=config.N_LANDMARKS * 2, kernel_size=(1, 1), activation='linear')(mobilenet.output)
    x = GlobalAveragePooling2D()(x)
    # -------------------------------------------------------
    model = Model(inputs=mobilenet.input, outputs=x)
    optimizer = Adadelta()
    model.compile(optimizer=optimizer, loss=mae_loss)
    model.summary()
    return model
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 128, 128, 3) 0
_________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 130, 130, 3) 0
_________________________________________________________________
conv1 (Conv2D) (None, 64, 64, 32) 864
_________________________________________________________________
conv1_bn (BatchNormalization (None, 64, 64, 32) 128
_________________________________________________________________
conv1_relu (Activation) (None, 64, 64, 32) 0
_________________________________________________________________
conv_pad_1 (ZeroPadding2D) (None, 66, 66, 32) 0
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D) (None, 64, 64, 32) 288
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 64, 64, 32) 128
_________________________________________________________________
conv_dw_1_relu (Activation) (None, 64, 64, 32) 0
_________________________________________________________________
conv_pw_1 (Conv2D) (None, 64, 64, 64) 2048
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 64, 64, 64) 256
_________________________________________________________________
conv_pw_1_relu (Activation) (None, 64, 64, 64) 0
_________________________________________________________________
conv_pad_2 (ZeroPadding2D) (None, 66, 66, 64) 0
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D) (None, 32, 32, 64) 576
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 32, 32, 64) 256
_________________________________________________________________
conv_dw_2_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv_pw_2 (Conv2D) (None, 32, 32, 128) 8192
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_pw_2_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pad_3 (ZeroPadding2D) (None, 34, 34, 128) 0
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D) (None, 32, 32, 128) 1152
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_dw_3_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pw_3 (Conv2D) (None, 32, 32, 128) 16384
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_pw_3_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pad_4 (ZeroPadding2D) (None, 34, 34, 128) 0
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D) (None, 16, 16, 128) 1152
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 16, 16, 128) 512
_________________________________________________________________
conv_dw_4_relu (Activation) (None, 16, 16, 128) 0
_________________________________________________________________
conv_pw_4 (Conv2D) (None, 16, 16, 256) 32768
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_pw_4_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pad_5 (ZeroPadding2D) (None, 18, 18, 256) 0
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D) (None, 16, 16, 256) 2304
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_dw_5_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pw_5 (Conv2D) (None, 16, 16, 256) 65536
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_pw_5_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pad_6 (ZeroPadding2D) (None, 18, 18, 256) 0
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D) (None, 8, 8, 256) 2304
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 8, 8, 256) 1024
_________________________________________________________________
conv_dw_6_relu (Activation) (None, 8, 8, 256) 0
_________________________________________________________________
conv_pw_6 (Conv2D) (None, 8, 8, 512) 131072
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_6_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_7 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_7_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_7 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_7_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_8 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_8_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_8 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_8_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_9 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_9_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_9 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_9_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_10 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_10_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_10 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_10_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_11 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_11_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_11 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_11_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_12 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 4, 4, 512) 4608
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 4, 4, 512) 2048
_________________________________________________________________
conv_dw_12_relu (Activation) (None, 4, 4, 512) 0
_________________________________________________________________
conv_pw_12 (Conv2D) (None, 4, 4, 1024) 524288
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_pw_12_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv_pad_13 (ZeroPadding2D) (None, 6, 6, 1024) 0
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 4, 4, 1024) 9216
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_dw_13_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv_pw_13 (Conv2D) (None, 4, 4, 1024) 1048576
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_pw_13_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 4, 4, 156) 159900
_________________________________________________________________
global_average_pooling2d_1 ( (None, 156) 0
=================================================================
Total params: 3,388,764
Trainable params: 3,366,876
Non-trainable params: 21,888