Как я могу вычислить или изобразить ошибку каждого класса в CNN? - PullRequest
0 голосов
/ 06 июля 2019

Я использовал этот код для своего проекта, но я хочу вычислить или изобразить ошибку каждого класса. У меня есть 6 классов. как я могу это сделать?

def plot_history(net_history):
    history = network_history.history
    losses = history['loss']
    accuracies = history['acc']
    plt.xlabel('Epochs')
    plt.ylabel('loss')
    plt.plot(losses)

    plt.figure()
    plt.xlabel('Epochs')
    plt.ylabel('accuracy')
    plt.plot(accuracies)

Создать мою модель

myinput = layers.Input(shape=(100,200))
conv1 = layers.Conv1D(16, 3, activation='relu', padding='same', strides=2)(myinput)
conv2 = layers.Conv1D(32, 3, activation='relu', padding='same', strides=2)(conv1)
flat = layers.Flatten()(conv2)
out_layer = layers.Dense(6, activation='softmax')(flat)

mymodel = Model(myinput, out_layer)
mymodel.summary()
mymodel.compile(optimizer=keras.optimizers.Adam(), 
loss=keras.losses.categorical_crossentropy, metrics=['accuracy'])

тренируй мою модель

network_history = mymodel.fit(X_train, Y_train, batch_size=128,epochs=5, validation_split=0.2)
plot_history(network_history)

Оценка

test_loss, test_acc = mymodel.evaluate(X_test, Y_test)

test_labels_p = mymodel.predict(X_test)

Ответы [ 2 ]

0 голосов
/ 07 июля 2019

Вы должны обучить ее как задачу бинарной классификации, а затем использовать этот код для построения кривых обучения для разных классов:

plt.plot(network_history.history['loss'])
plt.plot(network_history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
0 голосов
/ 06 июля 2019

Простой способ оценки классификатора - это classification_report в scikit-learn:

from sklearn.metrics import classification_report

....

# Actual predictions here, not just probabilities
pred = numpy.round(mymodel.predict(X_test))
print(classification_report(Y_test, pred))

, где Y_test - список горячих векторов.

Это будетпоказать вам точность, вспомнить и мера f1 для каждого класса.Недостатком является то, что он только учитывает, был ли прогноз верным или неправильным, и не учитывает достоверность модели.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...