Если вы посмотрите документацию для roc_curve()
, вы увидите следующее относительно параметра y_score
:
y_score: array, shape = [n_samples]Целевые оценки могут быть либо оценками вероятности положительного класса , доверительными значениями или мерой решения без пороговых значений (как возвращено параметром solution_function для некоторых классификаторов).
Вы можете получить оценки вероятности, используя predict_proba()
метод KNeighborsClassifier
в sklearn.Это возвращает пустой массив с двумя столбцами для двоичной классификации, по одному для отрицательного и положительного классов.Для функции roc_curve()
вы хотите использовать оценки вероятностей положительного класса , поэтому вы можете заменить:
y_scores = cross_val_score(knn_cv, X, y, cv=76)
fpr, tpr, threshold = roc_curve(y_test, y_scores)
на:
y_scores = knn.predict_proba(X_test)
fpr, tpr, threshold = roc_curve(y_test, y_scores[:, 1])
Уведомлениекак вам нужно взять все строки второго столбца с помощью [:, 1]
, чтобы выбрать только оценки вероятностей положительного класса.Вот минимальный воспроизводимый пример с использованием набора данных о раке молочной железы из Висконсина, поскольку у меня нет вашего autoimmune.csv
:
from sklearn.datasets import load_breast_cancer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
import matplotlib.pyplot as plt
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
knn = KNeighborsClassifier(n_neighbors = 10)
knn.fit(X_train,y_train)
y_scores = knn.predict_proba(X_test)
fpr, tpr, threshold = roc_curve(y_test, y_scores[:, 1])
roc_auc = auc(fpr, tpr)
plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.title('ROC Curve of kNN')
plt.show()
. В результате получается следующая кривая ROC: