На основе http://scikit -learn.org / stable / auto_examples / tree / plot_unveil_tree_structure.html
Предполагая, что вероятности равны доле классов в каждом узле, например, если лист содержит68 экземпляров с классом 0 и 15 с классом 1 (то есть value
в tree_
составляет [68,15]) вероятности составляют [0.81927711, 0.18072289]
.
Генерация простого дерева, 4 объекта, 2 класса:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from sklearn.cross_validation import train_test_split
from sklearn.tree import _tree
X, y = make_classification(n_informative=3, n_features=4, n_samples=200, n_redundant=1, random_state=42, n_classes=2)
feature_names = ['X0','X1','X2','X3']
Xtrain, Xtest, ytrain, ytest = train_test_split(X,y, random_state=42)
clf = DecisionTreeClassifier(max_depth=2)
clf.fit(Xtrain, ytrain)
Визуализируйте это:
from sklearn.externals.six import StringIO
from sklearn import tree
import pydot
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue()) [0]
graph.write_jpeg('1.jpeg')
Создайте функцию для печати условия для одного экземпляра:
node_indicator = clf.decision_path(Xtrain)
n_nodes = clf.tree_.node_count
feature = clf.tree_.feature
threshold = clf.tree_.threshold
leave_id = clf.apply(Xtrain)
def value2prob(value):
return value / value.sum(axis=1).reshape(-1, 1)
def print_condition(sample_id):
print("WHEN", end=' ')
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
node_indicator.indptr[sample_id + 1]]
for n, node_id in enumerate(node_index):
if leave_id[sample_id] == node_id:
values = clf.tree_.value[node_id]
probs = value2prob(values)
print('THEN Y={} (probability={}) (values={})'.format(
probs.argmax(), probs.max(), values))
continue
if n > 0:
print('&& ', end='')
if (Xtrain[sample_id, feature[node_id]] <= threshold[node_id]):
threshold_sign = "<="
else:
threshold_sign = ">"
if feature[node_id] != _tree.TREE_UNDEFINED:
print(
"%s %s %s" % (
feature_names[feature[node_id]],
#Xtrain[sample_id,feature[node_id]] # actual value
threshold_sign,
threshold[node_id]),
end=' ')
Вызовите его в первой строке:
>>> print_condition(0)
WHEN X1 > -0.2662498950958252 && X0 > -1.1966443061828613 THEN Y=1 (probability=0.9672131147540983) (values=[[ 2. 59.]])
Вызовите его во всех строках, где прогнозируемое значение равно нулю:
[print_condition(i) for i in (clf.predict(Xtrain) == 0).nonzero()[0]]