Перекрестная проверка с CNN - PullRequest
1 голос
/ 30 апреля 2019

Я хотел бы знать, выполняет ли мой код то, что я хочу делать;Чтобы дать вам некоторое представление, я внедряю CNN для классификации изображений.Я пытаюсь использовать перекрестную проверку для сравнения моей другой архитектуры нейронной сети

здесь код:


def create_model():
    model = Sequential()
    model.add(Conv2D(24,kernel_size=3,padding='same',activation='relu',
            input_shape=(96,96,1)))
    model.add(MaxPool2D())
    model.add(Conv2D(48,kernel_size=3,padding='same',activation='relu'))
    model.add(MaxPool2D())
    model.add(Conv2D(64,kernel_size=3,padding='same',activation='relu'))
    model.add(MaxPool2D())
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(12, activation='softmax'))
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model
model = KerasClassifier(build_fn=create_model, epochs=5, batch_size=20, verbose=1) 
# 3-Fold Crossvalidation
kfold = KFold(n_splits=3, shuffle=True, random_state=2019) 
results = cross_val_score(model, train_X, train_Y_one_hot, cv=kfold)

model.fit(train_X, train_Y_one_hot,validation_data=(valid_X, valid_label),class_weight=class_weights)
y_pred = model.predict(test_X)

test_eval = model.evaluate(test_X, y_pred, verbose=0)

Я нашел часть для перекрестной проверки в Интернете.Но у меня есть некоторая проблема, чтобы понять это.

Мой вопрос: 1 => Могу ли я использовать перекрестную проверку для повышения моей точности?Например, я запускаю 10 раз, когда моя нейронная сеть и моя модель получают вес, при котором достигается лучшая точность

2 => Если я хорошо понимаю, в приведенном выше коде результаты запустите мою CNN 3время и покажи мне точность.Но когда я использую model.fit , модель запускается только один раз;Я прав?

Спасибо за вашу помощь

Ответы [ 2 ]

1 голос
/ 30 апреля 2019

Здесь есть два ключевых термина, с которыми вы должны ознакомиться:

  1. Гиперпараметры
  2. Параметры

Гиперпараметры управляют общей архитектурой модели,Это то, что контролирует программист или ученый.В случае CNN это относится к числу слоев, их конфигурациям, активациям, оптимизаторам и т. Д. Для простой модели полиномиальной регрессии это будет степень полинома.

Параметры относятся к фактическим значениям весов или коэффициентов, которыми заканчивается модель после того, как она решает оптимизацию с использованием градиентного спуска или любого другого метода, который вы используете.В CNN это будет матрица весов для каждого слоя.Для полиномиальной регрессии это были бы коэффициенты и смещение.

Перекрестная проверка используется для поиска наилучшего набора гиперпараметров.Наилучший набор параметров получен оптимизатором (градиентный спуск, Адам и т. Д.) Для заданного набора гиперпараметров и данных.

Чтобы ответить на ваши вопросы:

  1. Вы должны выполнить перекрестную проверку несколько раз, каждый раз с другой конфигурацией гиперпараметра (сетевая архитектура).Это единственное, что вы можете контролировать.В конце вы выбираете лучшую архитектуру, основанную на точности.Веса модели были бы разными для каждого сгиба, но поиск оптимальных весов - это работа оптимизатора, а не ваша.

  2. Да.В 3-кратном резюме модель тренируется 3 раза и оценивается 3 раза.Когда вы делаете model.fit, вы делаете прогнозы один раз для нового набора данных.

1 голос
/ 30 апреля 2019
  1. Не совсем, перекрестная проверка - это больше способ предотвратить переоснащение / не быть сбитым с толку ненормальными результатами, полученными из плохо разделенного набора данных -> получить превосходную оценку производительности вашей модели. Если вы хотите настроить Гиперпараметры вашей модели, вам лучше использовать sklearn.model_selection.GridSearchCV / sklearn.model_selection.RandomSearchCV

  2. при выполнении cross_val_score за каждый поезд / тест sklearn делает fit затем predict / evaluate, поэтому для каждого нового экземпляра модели, у вас есть 1 fit, затем 1 predict / evaluate; В противном случае ваша перекрестная проверка недействительна, поскольку она зависит от подгонки к предыдущему набору данных (и, возможно, от тестовых данных!)

...