Как найти кривую RO C и оценку AU C этой модели CNN (керас) - PullRequest
0 голосов
/ 30 апреля 2020

Мой код CNN в керасе выглядит следующим образом:

from keras.models import Sequential
from keras.layers import Convolution2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.layers import Dropout

classifier = Sequential()
#1st Conv layer
classifier.add(Convolution2D(64, (9, 9), input_shape=(64, 64, 3), activation='relu'))
classifier.add(MaxPooling2D(pool_size=(4,4)))
#2nd Conv layer
classifier.add(Convolution2D(32, (3, 3), activation='relu'))
classifier.add(MaxPooling2D(pool_size=(2,2)))

#Flattening
classifier.add(Flatten())

# Step 4 - Full connection
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dropout(0.1))
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dropout(0.2))
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dense(units = 2, activation = 'softmax'))

classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

#Fitting dataset

from keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator(rescale = 1./255)

training_set = train_datagen.flow_from_directory('dataset/training_set',
                                                 target_size = (64, 64),
                                                 batch_size = 32,
                                                 class_mode = 'categorical')

test_set = test_datagen.flow_from_directory('dataset/test_set',
                                            target_size = (64, 64),
                                            batch_size = 32,
                                            class_mode = 'categorical')

classifier.fit_generator(
        training_set,
        steps_per_epoch=(1341+3875)/32,
        epochs=15,
        validation_data=test_set,
        validation_steps=(234+390)/32)

Где бы я ни увидел использование roc_curve из sklearn.metrics, он принимает такие параметры, как x_train, y_train, x_test, y_test, которые, как я знаю, могут быть pandas DataFrames, но в моем случае это не так. Как построить график RO C и получить оценку AU C для обучения модели для CNN, как здесь?

Ответы [ 2 ]

1 голос
/ 30 апреля 2020

На самом деле, если взглянуть на документы sklearn.metrics.roc_curve (и почти для каждого показателя sklearn c), они не принимают входные данные вашей модели (изображений) в качестве аргументов, а просто принимают истинные метки и предсказанный ярлык. Таким образом, после того, как вы сделаете вывод на тестовом наборе, который в керасе (здесь я только догадываюсь) выглядит примерно так:

preds = classifier.predict(batch)

Вы вызываете roc_curve как

fpr, tpr = roc_curve(true_labels,preds)

Возможно, вам нужно изменить тип, потому что они тензорные.

РЕДАКТИРОВАТЬ: Я проверил документацию keras на flow_from_directory и выдает итератор над (x,y) = (images,labels), поэтому, если вы хотите выполнить какой-либо анализ после обучения, вы должны получить метки, используя что-то вроде этого:

labels = []
for _,y in test_set:
    labels.extend(list(y))

И если у вас есть только два класса, измените class_mode на binary

0 голосов
/ 01 мая 2020

Я получил это работает. Все, что мне нужно было сделать, это сопоставить тип данных pred, полученный из preds = classifier.predict(test_set), с true_labels, которые я получил из labels = test_set. Preds - это, в основном, numpy .ndarray, содержащий списки отдельных элементов, которые имеют значения np.float32. Преобразование меток в тот же формат и форму заставило работать roc_curve.

Кроме того, мне пришлось добавить третье пороговое значение переменной в fpr, tpr, threshold = roc_curve(true_labels, preds), поэтому значение ValueError отсутствует: слишком много значений, чтобы распаковать сообщение об ошибке.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...