Ошибка матрицы путаницы в графике с OpenCV Python - PullRequest
0 голосов
/ 18 апреля 2019

Я использую opencv-python confusion_matrix из sklearn.metrics, чтобы построить матрицу путаницы для своей задачи.Это работает, когда я строю матрицу скорости восстановления (установите axis=1 в моем коде)Но когда я хочу построить матрицу точности (установите axis=0 в моем коде), номер моей матрицы неверен.Вот мои матрицы:

enter image description here

Например, в матрице путаницы средней точности 0,8 (первый ряд, первый столбец) плюс 0,02 (второй ряд, первыйстолбец) должен быть 1.0, но это не так.Можете ли вы сказать мне, где это не так?

А это мой код определения функции

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues,
                          axis=1): 

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=axis)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

Это код о построении рисунка:

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_true, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_name,
                      title='Confusion matrix, without normalization')

# Plot normalized Precision confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_name, normalize=True,
                      title='Normalized Precision confusion matrix',
                      axis=0)

# Plot normalized Recall confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_name, normalize=True,
                      title='Normalized Recall confusion matrix',
                      axis=1)

plt.show()

Спасибоочень за вашу помощь!

...