Построение кривой RO C и AU C в кератах для бинарной классификации - PullRequest
0 голосов
/ 21 апреля 2020
#more code before this
datagen = ImageDataGenerator(rescale = 1./255)

train_generator = datagen.flow_from_directory(directory=train_data_dir,
    target_size=(img_width,img_height),
    classes=['ham','spam'],
    class_mode='binary',
    batch_size=16)

validation_generator = datagen.flow_from_directory(directory=valid_data_dir,
    target_size=(img_width,img_height),
    classes=['ham','spam'],
    class_mode='binary',
    batch_size=32)

model =Sequential()

model.add(Conv2D(32,(3,3), input_shape=(img_width, img_height, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Conv2D(32,(3,3), input_shape=(img_width, img_height, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Conv2D(64,(3,3), input_shape=(img_width, img_height, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',optimizer='rmsprop',metrics=['accuracy'])

training = model.fit_generator(generator=train_generator, steps_per_epoch=2048 // 16,epochs=1,validation_data=validation_generator,validation_steps=832//16)


Я попытался напечатать матрицу путаницы для построения кривой AU C и Ro C, добавив этот дополнительный код

Y_pred = model.predict_generator(validation_generator, num_of_test_samples // batch_size+1)
y_pred = np.argmax(Y_pred, axis=1)

print(confusion_matrix(validation_generator.classes, y_pred))

, а затем возвращается ошибка

Traceback (most recent call last):
  File "CNN.py", line 65, in <module>
    print(confusion_matrix(validation_generator.classes, y_pred))
  File "/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py", line 268, in confusion_matrix
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py", line 80, in _check_targets
    check_consistent_length(y_true, y_pred)
  File "/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py", line 212, in check_consistent_length
    " samples: %r" % [int(l) for l in lengths])
ValueError: Found input variables with inconsistent numbers of samples: [100, 196]

Я просто не могу понять ошибку Эй, ребята, это мой код, я просто не могу построить кривую RO C и AU C в кератах, поэтому любые рекомендации о том, как построить RO C и AU C для двоичной классификации, я мог видеть категориальные классификации, но без двоичных классификаций.

...