Как я могу получить один столбец в моем питоне (pandas) Dataframe, чтобы увидеть все правила моего дерева решений, которые привели меня к моему результату? - PullRequest
0 голосов
/ 19 июня 2019

Я работаю над деревом решений (классификатором) для sklearn, и оно работает хорошо, я могу визуализировать дерево и прогнозировать свой класс. Но я хотел бы создать один столбец (в моем фрейме данных pandas), который является путем для получения моего результата в дереве. Я имею в виду, я хотел бы объединение всех правил, чтобы получить мой результат, как: - Белый = Ложь, Черный = Ложь, Вес = 1, Цена = 5. У вас есть идеи, пожалуйста?

1 Ответ

1 голос
/ 20 июня 2019

На основе примера здесь вы можете создать свое объяснение применяемых правил.

  • estimator.decision_path дает вам узлы, которые следуют, чтобы получить результат
  • is_leaves - это массив, который хранит для каждого узла, если это лист, то есть терминал, (True) или ветвь / решение (False)
  • Затем можно выполнить итерациюnode_indicator для получения посещенных узлов
  • Для каждого узла вы можете получить threshold и соответствующую feature
  • Наконец apply функцию для вашего фрейма данных ивы сделали.

    def get_decision_path(estimator, feature_names, sample, precision=2, is_leaves=None):
        if is_leaves is None:
            is_leaves = get_leaves(estimator)
        feature = estimator.tree_.feature
        threshold = estimator.tree_.threshold
    
        text = []
    
        node_indicator = estimator.decision_path([sample])
        node_index = node_indicator.indices[node_indicator.indptr[0]:
                                            node_indicator.indptr[1]]
    
        for node_id in node_index:
            if is_leaves[node_id]:
                break
    
            if sample[feature[node_id]] <= threshold[node_id]:
                threshold_sign = "<="
            else:
                threshold_sign = ">"
    
            text.append('{}: {} {} {}'.format(feature_names[feature[node_id]],
                                              sample[feature[node_id]],
                                              threshold_sign,
                                              round(threshold[node_id], precision)))
    
        return '; '.join(text)
    
    def get_leaves(estimator):
        n_nodes = estimator.tree_.node_count
        children_left = estimator.tree_.children_left
        children_right = estimator.tree_.children_right
        is_leaves = np.zeros(shape=n_nodes, dtype=bool)
        stack = [(0, -1)]
        while len(stack) > 0:
            node_id, parent_depth = stack.pop()
    
            if children_left[node_id] != children_right[node_id]:
                stack.append((children_left[node_id], parent_depth + 1))
                stack.append((children_right[node_id], parent_depth + 1))
            else:
                is_leaves[node_id] = True
        return is_leaves
    

Пример

print(get_decision_path(estimator, 
                        ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'], 
                        [6.6, 3.0 , 4.4, 1.4]))

'petal width (cm): 1.4 > 0.8; petal length (cm): 4.4 <= 4.95; petal width (cm): 1.4 <= 1.65'

Полныйкод

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
from sklearn import tree
import pydotplus
from IPython.core.display import HTML, display

def get_decision_path(estimator, feature_names, sample, precision=2, is_leaves=None):
    if is_leaves is None:
        is_leaves = get_leaves(estimator)
    feature = estimator.tree_.feature
    threshold = estimator.tree_.threshold

    text = []

    node_indicator = estimator.decision_path([sample])
    node_index = node_indicator.indices[node_indicator.indptr[0]:
                                        node_indicator.indptr[1]]

    for node_id in node_index:
        if is_leaves[node_id]:
            break

        if sample[feature[node_id]] <= threshold[node_id]:
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        text.append('{}: {} {} {}'.format(feature_names[feature[node_id]],
                                          sample[feature[node_id]],
                                          threshold_sign,
                                          round(threshold[node_id], precision)))

    return '; '.join(text)


def get_leaves(estimator):
    n_nodes = estimator.tree_.node_count
    children_left = estimator.tree_.children_left
    children_right = estimator.tree_.children_right
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, -1)]
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()

        if children_left[node_id] != children_right[node_id]:
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True
    return is_leaves

# prepare data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target

X = df.iloc[:, 0:4].to_numpy()
y = df.iloc[:, 4].to_numpy()
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# create decision tree
estimator = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0)
estimator.fit(X_train, y_train)

# visualize decision tree
dot_data = tree.export_graphviz(estimator, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
svg = graph.create_svg()
display(HTML(svg.decode('utf-8')))

# add explanation to data frame
is_leaves = get_leaves(estimator)
df['explanation'] = df.apply(lambda row: get_decision_path(estimator, df.columns[0:4], row[0:4], is_leaves=is_leaves), axis=1)

df.sample(5, axis=0, random_state=42)

enter image description here enter image description here

...