Как отобразить путь дерева решений для тестовых образцов? - PullRequest
0 голосов
/ 27 апреля 2019

Я использую DecisionTreeClassifier из scikit-learn, чтобы классифицировать некоторые мультиклассовые данные.Я нашел много постов, описывающих, как отображать путь дерева решений, например здесь , здесь и здесь .Тем не менее, все они описывают, как отобразить дерево для обученных данных.Это имеет смысл, поскольку для export_graphviz требуется только подобранная модель.

Мой вопрос заключается в том, как визуализировать дерево на тестовых образцах (предпочтительно с помощью export_graphviz).Т.е. после подгонки модели с clf.fit(X[train], y[train]), а затем прогнозирования результатов для тестовых данных на clf.predict(X[test]), я хочу визуализировать путь принятия решения, используемый для прогнозирования выборок X[test].Есть ли способ сделать это?

Редактировать:

Я вижу, что путь может быть напечатан с использованием solution_path .Если есть способ получить вывод DOT с export_graphviz для его отображения, это было бы замечательно.

1 Ответ

3 голосов
/ 27 апреля 2019

Чтобы получить путь, выбранный для конкретной выборки в дереве решений, вы можете использовать decision_path.Он возвращает разреженную матрицу с путями принятия решений для предоставленных выборок.

Затем эти пути принятия решений можно использовать для окрашивания / маркировки дерева, сгенерированного с помощью pydot.Это требует перезаписи цвета и метки (что приводит к небольшому количеству уродливого кода).

Примечания

  • decision_path могут брать образцы из обученияустановить или новые значения
  • вы можете сойти с ума с цветами и изменить цвет в зависимости от количества образцов или любой другой визуализации может потребоваться

Пример

В приведенном ниже примере посещенный узел окрашен в зеленый цвет, все остальные узлы - в белый.

enter image description here

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)

# empty all nodes, i.e.set color to white and number of samples to zero
for node in graph.get_node_list():
    if node.get_attributes().get('label') is None:
        continue
    if 'samples = ' in node.get_attributes()['label']:
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = 0'
        node.set('label', '<br/>'.join(labels))
        node.set_fillcolor('white')

samples = iris.data[129:130]
decision_paths = clf.decision_path(samples)

for decision_path in decision_paths:
    for n, node_value in enumerate(decision_path.toarray()[0]):
        if node_value == 0:
            continue
        node = graph.get_node(str(n))[0]            
        node.set_fillcolor('green')
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)

        node.set('label', '<br/>'.join(labels))

filename = 'tree.png'
graph.write_png(filename)
...