Обучение CNN с трансферным обучением в Керасе - ввод изображений не работает, а векторный ввод - PullRequest
1 голос
/ 06 апреля 2019

Я пытаюсь сделать перевод обучения в Керасе.Я установил сеть ResNet50, которая не может быть обучена с некоторыми дополнительными слоями:

# Image input
model = Sequential()
model.add(ResNet50(include_top=False, pooling='avg')) # output is 2048
model.add(Dropout(0.05))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.15))
model.add(Dense(512, activation='relu'))
model.add(Dense(7, activation='softmax'))
model.layers[0].trainable = False
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

Затем я создаю входные данные: x_batch, используя функцию ResNet50 preprocess_input вместе с одной меткой горячего кодирования y_batch и выполните подгонку следующим образом:

model.fit(x_batch,
          y_batch,
          epochs=nb_epochs,
          batch_size=64,
          shuffle=True,
          validation_split=0.2,
          callbacks=[lrate])

Точность обучения приближается к 100% после десяти или около того эпох, но точность валидации фактически снижается примерно с 50% до 30% при неуклонно возрастающих потерях валидации.

Однако, если я вместо этого создам сеть только с последними слоями:

# Vector input
model2 = Sequential()
model2.add(Dropout(0.05, input_shape=(2048,)))
model2.add(Dense(512, activation='relu'))
model2.add(Dropout(0.15))
model2.add(Dense(512, activation='relu'))
model2.add(Dense(7, activation='softmax'))
model2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model2.summary()

и включу вывод прогноза ResNet50:

resnet = ResNet50(include_top=False, pooling='avg')
x_batch = resnet.predict(x_batch)

Тогда точность проверки возрастетпримерно до 85% ... Что происходит?Почему не работает метод ввода изображения?

Обновление:

Эта проблема действительно странная.Если я поменяю ResNet50 на VGG19, он, кажется, будет работать нормально.

Ответы [ 2 ]

1 голос
/ 07 апреля 2019

После долгих поисков я обнаружил, что проблема связана со слоями нормализации партии в ResNet. В VGGNet нет уровней пакетной нормализации, поэтому он работает для этой топологии.

Существует запрос на удаление, чтобы исправить это в Keras здесь , который объясняет более подробно:

Предположим, что мы используем одну из предварительно обученных CNN из Кераса, и мы хотим отрегулировать ее. К сожалению, мы не получаем никаких гарантий, что среднее значение и дисперсия нашего нового набора данных в слоях BN будут аналогичны исходному набору данных. В результате, если мы настроим верхние слои, их веса будут скорректированы на среднее значение / дисперсию нового набора данных. Тем не менее, во время вывода верхние слои будут получать данные, которые масштабируются с использованием среднего значения / дисперсии исходного набора данных. Это несоответствие может привести к снижению точности.

Это означает, что уровни BN подстраиваются под тренировочные данные, однако, когда проверка выполняется, используются исходные параметры уровней BN. Из того, что я могу сказать, исправление состоит в том, чтобы позволить замороженным слоям BN использовать обновленное среднее значение и отклонение от обучения.

Обходной путь - предварительно рассчитать выход ResNet. Фактически это значительно сокращает время обучения, поскольку мы не повторяем эту часть расчета.

0 голосов
/ 06 апреля 2019

вы можете попробовать:

Res = keras.applications.resnet.ResNet50(include_top=False, 
              weights='imagenet',  input_shape=(IMG_SIZE , IMG_SIZE , 3 ) )


    # Freeze the layers except the last 4 layers
for layer in vgg_conv.layers  :
   layer.trainable = False

# Check the trainable status of the individual layers
for layer in vgg_conv.layers:
    print(layer, layer.trainable)

# Vector input
model2 = Sequential()
model2.add(Res)
model2.add(Flatten())
model2.add(Dropout(0.05 ))
model2.add(Dense(512, activation='relu'))
model2.add(Dropout(0.15))
model2.add(Dense(512, activation='relu'))
model2.add(Dense(7, activation='softmax'))
model2.compile(optimizer='adam', loss='categorical_crossentropy', metrics =(['accuracy'])
model2.summary()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...