Я наткнулся на следующий код и мне было интересно, что именно keras.layers.concatenate
делает в этом случае.
Наилучшее предположение:
- В
fire_module()
, y
обучения на основе каждого пикселя (kernel_size=1
) y1
обучения на основе каждого пикселя activation map
из y
(kernel_size=1
) y3
обучается на основе области 3x3 пикселей activation map
из y
(kernel_size=3
) concatenate
объединяет y1
и y3
вместе, что означает общее filters
теперь является суммой фильтров в y1
и y3
- Эта конкатенация представляет собой среднее значение для обучения, основанного на каждом пикселе, обучения, основанного на 3x3, оба основаны на предыдущей карте активации, основанной на каждый пиксель, что делает модель лучше?
Любая помощь приветствуется.
def fire(x, squeeze, expand):
y = Conv2D(filters=squeeze, kernel_size=1, activation='relu', padding='same')(x)
y = BatchNormalization(momentum=bnmomemtum)(y)
y1 = Conv2D(filters=expand//2, kernel_size=1, activation='relu', padding='same')(y)
y1 = BatchNormalization(momentum=bnmomemtum)(y1)
y3 = Conv2D(filters=expand//2, kernel_size=3, activation='relu', padding='same')(y)
y3 = BatchNormalization(momentum=bnmomemtum)(y3)
return concatenate([y1, y3])
def fire_module(squeeze, expand):
return lambda x: fire(x, squeeze, expand)
x = Input(shape=[144, 144, 3])
y = BatchNormalization(center=True, scale=False)(x)
y = Activation('relu')(y)
y = Conv2D(kernel_size=5, filters=16, padding='same', use_bias=True, activation='relu')(x)
y = BatchNormalization(momentum=bnmomemtum)(y)
y = fire_module(16, 32)(y)
y = MaxPooling2D(pool_size=2)(y)
Редактировать:
Чтобы быть немного более точным c, почему бы не иметь это:
# why not this?
def fire(x, squeeze, expand):
y = Conv2D(filters=squeeze, kernel_size=1, activation='relu', padding='same')(x)
y = BatchNormalization(momentum=bnmomemtum)(y)
y = Conv2D(filters=expand//2, kernel_size=1, activation='relu', padding='same')(y)
y = BatchNormalization(momentum=bnmomemtum)(y)
y = Conv2D(filters=expand//2, kernel_size=3, activation='relu', padding='same')(y)
y = BatchNormalization(momentum=bnmomemtum)(y)
return y