Как вернуть массив ложных срабатываний из путаницы в scikit-learn? - PullRequest
0 голосов
/ 30 апреля 2018

Я работаю над созданием двоичного классификатора в scikit-learn , который будет классифицировать текстовые обзоры. Базовый рабочий процесс включает в себя следующее:

#Splitting the data into training and testing sets.
X_train, X_test, y_train, y_test = train_test_split(X, y, 
test_size=0.20, random_state=42)

#Instantiate a model
nb = MultinomialNB()

#Train the model.
nb.fit(X_train, y_train)

#Make predictions using the trained model
y_pred_class = nb.predict(X_test)

#View confusion matrix
confusion_matrix(y_test, y_pred_class)

#Output of confusion matrix
array([[295,  13],
      [ 80,  70]])

На основе матрицы путаницы существует 13 ложных срабатываний и 80 ложных отрицательных.

Я хочу видеть 13 текстовых обзоров, которые классифицируются как ложноположительные.

Я следил за этим постом , чтобы посмотреть, смогу ли я получить список из 13 пунктов, которые классифицируются как ложные срабатывания.

Однако, когда я запускаю следующее:

X_test[y_test != y_pred_class]

Я получаю следующий объект:

<458x758 sparse matrix of type '<class 'numpy.float64'>'
with 16890 stored elements in Compressed Sparse Row format>

Похоже, что возвращаются все значения в X_test (всего 458 записей). Я ожидал объект, который был менее 458 записей.

Я также ожидал увидеть текстовые данные X_test в отличие от объекта.

У меня такой вопрос:

Как я могу вернуть 13 записей из X_test, которые были ошибочно классифицированы как ложные срабатывания? Я ищу вывод, похожий на приведенный ниже пример.

2175    This has to be the worst restaurant in terms o...
1781    If you like the stuck up Scottsdale vibe this ...
2674    I'm sorry to be what seems to be the lone one ...
Name: text, dtype: object

1 Ответ

0 голосов
/ 30 апреля 2018

Для ложных срабатываний необходимо также проверить значения, которые равны 1 в y_pred_class, в дополнение к y_test != y_pred_class.

Попробуйте это:

import numpy as np
false_positives = np.logical_and(y_test != y_pred_class, y_pred_class == 1)

X_test[false_positives]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...