Для этого вы можете вручную пройтись по подгонянному дереву, получив доступ к свойствам, недоступным через publi c api.
Сначала давайте возьмем подгонянное дерево, используя набор данных "iris":
import numpy as np # linear algebra
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
data = load_iris()
clf = DecisionTreeClassifier(max_depth=3).fit(data['data'],data['target'])
Давайте визуализируем это дерево, в первую очередь для отладки нашей финальной программы:
plt.figure(figsize=(10,8))
plot_tree(clf,feature_names=data['feature_names'],class_names=data['target_names'],filled=True);
Что выводит в моем случае:
Теперь основной часть. Из этой ссылки мы знаем, что -
Двоичное дерево "tree_" представляется в виде ряда параллельных массивов. I-й элемент каждого массива содержит информацию об узле i
.
. Нам нужны массивы feature
, value
, threshold
и два children_*
. Итак, начиная с root (i=0
), мы сначала собираем функцию и порог для каждого посещаемого узла, запрашиваем у пользователя значение этой конкретной функции и переходим влево или вправо, сравнивая данное значение с порогом. Когда мы достигаем листа, мы находим самый частый класс в этом листе, и это заканчивает наш l oop.
tree = clf.tree_
node = 0 #Index of root node
while True:
feat,thres = tree.feature[node],tree.threshold[node]
print(feat,thres)
v = float(input(f"The value of {data['feature_names'][feat]}: "))
if v<=thres:
node = tree.children_left[node]
else:
node = tree.children_right[node]
if tree.children_left[node] == tree.children_right[node]: #Check for leaf
label = np.argmax(tree.value[node])
print("We've reached a leaf")
print(f"Predicted Label is: {data['target_names'][label]}")
break
Пример такого запуска для вышеприведенного дерева:
3 0.800000011920929
The value of petal width (cm): 1
3 1.75
The value of petal width (cm): 1.5
2 4.950000047683716
The value of petal length (cm): 5.96
We've reached a leaf
Predicted Label is: virginica