Модель Keras для ориентации автомобиля - PullRequest
0 голосов
/ 30 сентября 2019

Я новичок в Keras и TensorFlow. Я пытаюсь обучить сверточную нейронную сеть для классификации изображений. У меня есть большое количество изображений транспортных средств, и мне нужно получить классификацию ориентации. Здесь - это подход с использованием гистограммы ориентированных градиентов (HOG) (мне также нужен класс крыши автомобиля, всего девять классов).

Ниже приведен код Python для моего CNN.

model = Sequential([
    Conv2D(32, (5, 5), input_shape=(1536, 2048, 3)),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    # One output network layer with 9 nodes (corresponding to the 9 final classes/orientations)
    Dense(9, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'] 
)

model.fit(
    train_images,                 # Training data
    to_categorical(train_labels), # Training targets
    epochs=4,
    batch_size=16
)

Как видите, мой CNN состоит всего из трех слоев. Мой учебный набор данных состоит из выборки из 100 различных транспортных средств, каждый из которых имеет все девять требуемых ориентаций, в общей сложности 900 изображений. С четырьмя эпохами мне удалось достичь 90% точности на этапе обучения, но только ~ 50% точности на этапе тестирования. Эта модель плохо себя ведет в прогнозировании результатов для изображений, которые она никогда не видела раньше.

Моя модель CNN очень проста. Мне пришлось уменьшить количество эпох с 5 до 4, потому что оно начало переоснащаться в пятой эпохе. Мои вопросы: как я могу улучшить свою модель? Является ли мой набор данных достаточно большим? Нужно ли добавлять дополнительные слои в модель?

Заранее спасибо.

Обновление

Вот модель с обзором:

model = Sequential([
    Conv2D(32, (3, 3), input_shape=(224, 224, 3)),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(64, (3, 3)),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(128, (3, 3)),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(128, (3, 3)),
    Conv2D(256, (3, 3)),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(256, (3, 3)),
    MaxPooling2D(pool_size=(2, 2)),

    Flatten(),
    # One output network layer with 9 nodes (corresponding to the 9 final classes/orientations)
    Dense(9, activation='softmax')
])

Ответы [ 2 ]

0 голосов
/ 30 сентября 2019

Первый вопрос:

Я бы использовал лучшую маленькую модель, такую ​​как MobileNets или EfficientNetB0, потому что ваша модель сейчас слишком мала, она даже не будет работать в mnist. Но будьте осторожны, чтобы не использовать слишком большую модель, потому что в небольшом наборе данных может быть проще наложение.

Кроме того, уменьшите размер входного изображения, обычно достаточно 224 * 224, большое изображение приводит к наложению на маленький набор данных.

Самое важное замечание: не забудьте использовать увеличение данных, например, ImageDataGenerator в кератах.

Второй вопрос:

Это не так, я думаю, это будет трудно даже с EfficientNetB0, по крайней мере, вам может потребоваться более 50 изображений на класс для традиционного классификатора изображений. Насколько я знаю, Siamese network или Matching network может дать лучший результат.

Третий вопрос:

Как и в первом случае, ваша сеть даже меньше, чем первый CNN, LeNet, который принимает 32x32 входных изображения.

Обновление:

Используйте это вместо того, чтобы просто складывать больше слоев Conv.

from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models
base_model = MobileNetV2(include_top=False)
x = base_model.output
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(9, activation='softmax')(x)
model = models.Model(base_model.inputs, x)
0 голосов
/ 30 сентября 2019

Большая разница между тренировочной и тестовой успеваемостью часто является признаком переоснащения. Поскольку ваша модель уже довольно мелкая, вы можете рассмотреть возможность изменения количества объектов, рассчитанных в слое Conv2D. Как это работает с 16 или 24 функциями вместо 32?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...