Почему sklearn.metrics.confusion_matrix и sklearn.metrics.plot_confusion_matrix имеют несовместимые определения функций? - PullRequest
2 голосов
/ 17 апреля 2020

Я использую sklearn и заметил, что аргументы sklearn.metrics.plot_confusion_matrix и sklearn.metrics.confusion_matrix противоречивы. plot_confusion_matrix использует estimator и X для построения y_pred, в то время как confusion_matrix имеет y_pred в качестве аргумента напрямую.

Что может быть причиной этого несоответствия?

Определения частичной функции:

  • sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, ...) [где X должно быть X_test]
  • sklearn.metrics.confusion_matrix(y_true, y_pred, ...)

Источники:

1 Ответ

2 голосов
/ 18 апреля 2020

Да, вы правы, что для этого нет согласованного дизайна API, но идет обсуждение этой проблемы здесь .

Один быстрый обходной путь - ConfusionMatrixDisplay.

пример:

from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
clf.fit(X_train, y_train)

from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred)

cm_display = ConfusionMatrixDisplay(cm, [0,1]).plot()
...