Поиск подходящей архитектуры и параметров модели CNN - PullRequest
1 голос
/ 16 февраля 2020

В настоящее время я создаю модель CNN, которая классифицирует, является ли шрифт Arial, Verdana, Times New Roman и Georgia. В целом есть классы 16, так как я подумал также определить, является ли шрифт regular, bold, italics или bold italics. Итак, 4 fonts * 4 styles = 16 classes.

Данные, которые я использовал при обучении, следующие:

 Training data set : 800 image patches of 256 * 256 dimension (50 for each class)
 Validation data set : 320 image patches of 256 * 256 dimension (20 for each class)
 Testing data set : 160 image patches of 256 * 256 dimension (10 for each class)

Ниже приведен пример скриншота моих данных:

enter image description here

Ниже приведен мой исходный код:

import numpy as np
import keras
from keras import backend as K
from keras.models import Sequential
from keras.layers import Activation
from keras.layers.core import Dense, Flatten
from keras.optimizers import Adam
from keras.metrics import categorical_crossentropy
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import *
from matplotlib import pyplot as plt
import itertools
import matplotlib.pyplot as plt
import pickle


image_width = 256
image_height = 256

train_path = 'font_model_data/train'
valid_path =  'font_model_data/valid'
test_path = 'font_model_data/test'


train_batches = ImageDataGenerator().flow_from_directory(train_path, target_size=(image_width, image_height), classes=['1','2','3','4', '5', '6', '7', '8', '9', '10', '11', '12','13', '14', '15', '16'], batch_size = 16)
valid_batches = ImageDataGenerator().flow_from_directory(valid_path, target_size=(image_width, image_height), classes=['1','2','3','4', '5', '6', '7', '8', '9', '10', '11', '12','13', '14', '15', '16'], batch_size = 16)
test_batches = ImageDataGenerator().flow_from_directory(test_path, target_size=(image_width, image_height), classes=['1','2','3','4', '5', '6', '7', '8', '9', '10', '11', '12','13', '14', '15', '16'], batch_size = 160)


 imgs, labels = next(train_batches)

 #CNN model
 model = Sequential([
     Conv2D(32, (3,3), activation='relu', input_shape=(image_width, image_height, 3)),
     Flatten(),
     Dense(16, activation='softmax'),
 ])

 print(model.summary())

 model.compile(Adam(lr=.0001),loss='categorical_crossentropy', metrics=['accuracy'])
 model.fit_generator(train_batches, steps_per_epoch = 50, validation_data= valid_batches, validation_steps = 20, epochs = 1, verbose = 2)

 model_pickle = open('cnn_font_model.pickle', 'wb')
 pickle.dump(model, model_pickle)
 model_pickle.close()
 print('Training Done.')

 test_imgs, test_labels = next(test_batches)

 predictions = model.predict_generator(test_batches, steps = 160, verbose = 2)
 print(predictions)

Кто-нибудь может подсказать, как мне узнать правильную архитектуру и параметры сети для получения оптимальной точности? Как начать настройку сети?

1 Ответ

1 голос
/ 16 февраля 2020

Прежде чем перейти к выбору сети, вам нужно сегментировать изображение мозаичного изображения в субтитры с символами и использовать следующую архитектуру ...

# Initialising the CNN
classifier = Sequential()

# Step 1 - Convolution
classifier.add(Conv2D(32, (3, 3), input_shape = (64, 64, 3), activation = 'relu'))

# Step 2 - Pooling
classifier.add(MaxPooling2D(pool_size = (2, 2)))

# Adding a second convolutional layer
classifier.add(Conv2D(32, (3, 3), activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))

# Step 3 - Flattening
classifier.add(Flatten())

# Step 4 - Full connection
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dense(units = 1, activation = 'sigmoid'))

# Compiling the CNN
classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
classifier.fit_generator(training_set,
                     steps_per_epoch = XXX,
                     epochs = XX,
                     validation_data = test_set,
                     validation_steps = XXX)
from keras.models import load_model
classifier.save('your_classifier.h5')
...