Исправление 100% точности с помощью DecisionTreeClassifier в scikit-learn - PullRequest
1 голос
/ 06 августа 2020

Я пытаюсь использовать дерево решений для классификации и получить 100% точность.

Это общая проблема, описанная здесь и здесь . И во многих других вопросах.

Данные здесь .

Два лучших предположения:

  • Я неправильно разделяю данные
  • Мой набор данных слишком несбалансирован

Что не так с моим кодом?

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import cross_val_score
import sklearn.model_selection as cv
from sklearn.metrics import mean_squared_error as MSE
from sklearn.model_selection import train_test_split 
from sklearn import metrics
from sklearn.metrics import confusion_matrix 
from sklearn.metrics import accuracy_score 

# Split data
Y = starbucks.iloc[:, 4]
X = starbucks.loc[:, starbucks.columns != 'offer_completed']

# Splitting the dataset into train and test 
X_train, X_test, y_train, y_test = train_test_split(X, Y, 
                                                    test_size=0.3,
                                                    random_state=100) 

# Creating the classifier object 
clf_gini = DecisionTreeClassifier(criterion = "gini", 
                                  random_state = 100, 
                                  max_depth = 3, 
                                  min_samples_leaf = 5) 

# Performing training 
clf_gini.fit(X_train, y_train)

# Predicton on test with giniIndex 
y_pred = clf_gini.predict(X_test) 
print("Predicted values:") 
print(y_pred) 

print("Confusion Matrix: ", confusion_matrix(y_test, y_pred)) 

print ("Accuracy : ", accuracy_score(y_test, y_pred)*100) 

print("Report : ", classification_report(y_test, y_pred)) 

y_pred_gini = prediction(X_test, clf_gini) 
cal_accuracy(y_test, y_pred_gini) 


Predicted values:
[0. 0. 0. ... 0. 0. 0.]
Confusion Matrix:  [[36095     0]
                    [    0  8158]]
Accuracy :  100.0

Когда я печатаю X, он показывает мне, что offer_completed было удалено.

X.dtypes

offer_received               int64
offer_viewed               float64
time_viewed_received       float64
time_completed_received    float64
time_completed_viewed      float64
transaction                float64
amount                     float64
total_reward               float64
age                        float64
income                     float64
male                         int64
membership_days            float64
reward_each_time           float64
difficulty                 float64
duration                   float64
email                      float64
mobile                     float64
social                     float64
web                        float64
bogo                       float64
discount                   float64
informational              float64

1 Ответ

2 голосов
/ 06 августа 2020

Подбирая модель и проверяя важность функций, вы можете увидеть, что все они нули, за исключением total_reward. Затем, вложив такой столбец, вы получите:

df.groupby(target)['total_reward'].describe()
    count   mean    std    min   25%    50%   75%    max
0   119995  0.0     0.0    0.0   0.0    0.0   0.0    0.0
1   27513   5.74    4.07   2.0   3.0    5.0   10.0   40.0

Вы можете видеть, что для цели 0 total_reward всегда равно нулю, иначе имеет значение всегда больше 0. Вот ваша утечка.

Поскольку могут быть и другие утечки, а проверять каждый столбец утомительно, мы можем использовать своего рода «предсказательную силу» только для каждой функции:

acc_df = pd.DataFrame(columns=['col', 'acc'], index=range(len(X.columns)))

for i, c in enumerate(X.columns):

    clf = DecisionTreeClassifier(criterion = "gini", 
                                 random_state = 100, 
                                 max_depth = 3, 
                                 min_samples_leaf = 5) 
    
    clf.fit(X_train[c].to_numpy()[:, None], y_train)
    
    y_pred = clf.predict(X_test[c].to_numpy()[:, None])
    acc_df.iloc[i] = [c, accuracy_score(y_test, y_pred)*100]


acc_df.sort_values('acc',ascending=False)
                 col      acc
8       total_reward      100
4     completed_time  99.8848
13  reward_each_time  89.3205
14        difficulty  89.3205
15          duration  89.3205
21          discount  86.4054
19               web   85.088
20              bogo  84.4801
3        viewed_time  84.4056
2       offer_viewed  84.3491
18            social  83.3525
1      received_time  83.0497
7             amount  82.5436
0     offer_received  81.7526
16             email  81.7526
17            mobile  81.6464
11              male  81.5651
10            income  81.5651
9                age  81.5651
6   transaction_time  81.5651
5        transaction  81.5651
22     informational  81.5651
12   membership_days  81.5561

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...