Как сделать BatchNormalisation обучаемым, когда все остальные промежуточные слои (в случае Re sNet, Dens eNet) заморожены? - PullRequest
0 голосов
/ 06 мая 2020
from keras.applications.densenet import DenseNet201
conv_base = DenseNet201(weights= 'imagenet', include_top=False, input_shape= (200,200,3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())

model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(3, activation='softmax'))


conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
  if layer.name == 'conv4_block40_0_bn':
    set_trainable = True
  if set_trainable:
    layer.trainable = True
  else:
    layer.trainable = False

Я использую его, чтобы заморозить верхние слои DenseNet201. Теперь как добавить или удалить BatchNormalisation / Dropout из предварительно обученной (DenseNet201) модели. Как сделать обучаемым или нет из любого места, где он присутствует? Как должен выглядеть код?

from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras import optimizers
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(200,200),
    batch_size=20,
    class_mode='categorical',
    shuffle = True
)
validation_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(200,200),
    batch_size=20,
    class_mode='categorical',
    shuffle = True)
checkpoint = ModelCheckpoint('model-{epoch:03d}-{acc:03f}-{val_acc:03f}.h5', verbose=1, monitor='val_loss',save_best_only=True, mode='auto')
model.compile(loss='categorical_crossentropy',optimizer=optimizers.RMSprop(lr=2e-6),metrics=['acc'])
history = model.fit_generator(
    train_generator,
    steps_per_epoch=1180//20,
    epochs=100,
    validation_data=validation_generator,
    validation_steps=290//20,
    callbacks=[checkpoint]
)

Я использую ImageDataGenerator для получения пакетных данных. Я новичок в глубоком обучении.

...