Как получить мультикласс roc_au c в перекрестной проверке в sklearn? - PullRequest
0 голосов
/ 24 марта 2020

У меня проблема с классификацией, когда я хочу получить значение roc_auc, используя cross_validate в sklearn. Мой код выглядит следующим образом.

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features.
y = iris.target

from sklearn.ensemble import RandomForestClassifier
clf=RandomForestClassifier(random_state = 0, class_weight="balanced")

from sklearn.model_selection import cross_validate
cross_validate(clf, X, y, cv=10, scoring = ('accuracy', 'roc_auc'))

Однако я получаю следующую ошибку.

ValueError: multiclass format is not supported

Обратите внимание, что я выбрал roc_auc, в частности, что он поддерживает классификацию binary и multiclass, как упоминалось в: https://scikit-learn.org/stable/modules/model_evaluation.html

У меня тоже есть набор данных бинарной классификации. Пожалуйста, дайте мне знать, как устранить эту ошибку.

Я с удовольствием предоставлю более подробную информацию, если это необходимо.

1 Ответ

2 голосов
/ 24 марта 2020

По умолчанию multi_class='raise', поэтому вам нужно явно , чтобы изменить это.

Из документов :

multi_class {'повысить', 'ovr', 'ovo'}, по умолчанию = 'повысить'

Только мультикласс. Определяет тип используемой конфигурации. Значение по умолчанию вызывает ошибку, поэтому либо 'ovr', либо 'ovo' должны быть переданы явно.

'ovr':

Вычисляет AU C каждого класса против остальных [ 3] [4]. Это относится к случаю мультикласса так же, как и к случаю с несколькими метками. Чувствителен к дисбалансу класса, даже когда average == 'macro', потому что дисбаланс класса влияет на состав каждой из групп «отдыха».

'ovo':

Вычисляет среднее AU C для всех возможные попарные комбинации классов [5]. Нечувствителен к дисбалансу класса, когда average == 'macro'.


Решение:

Использование make_scorer ( документы ):

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features.
y = iris.target

from sklearn.ensemble import RandomForestClassifier
clf=RandomForestClassifier(random_state = 0, class_weight="balanced")

from sklearn.metrics import make_scorer
from sklearn.metrics import roc_auc_score

myscore = make_scorer(roc_auc_score, multi_class='ovo',needs_proba=True)

from sklearn.model_selection import cross_validate
cross_validate(clf, X, y, cv=10, scoring = myscore)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...