Получение значения конечного узла в DecisionTreeRegressor - PullRequest
0 голосов
/ 18 сентября 2018

Я пытался проанализировать DecisionTreeRegressor, в котором я тренировался sklearn.Я нашел http://scikit -learn.org / stable / auto_examples / tree / plot_unveil_tree_structure.html полезной при определении атрибутов, разделяющих каждую ветвь дерева, в частности этот фрагмент кода:

n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1

    # If we have a test node
    if (children_left[node_id] != children_right[node_id]):
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True


print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (node_depth[i] * "\t",
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))

Однако, это не говорит мне значение каждого конечного узла.Если вышеизложенное выводит что-то похожее на это:

The binary tree structure has 7 nodes and has the following tree structure:
node=0 test node: go to node 1 if X[:, 2] <= 1.00764083862 else to node 4.
    node=1 test node: go to node 2 if X[:, 2] <= 0.974808812141 else to node 3.
        node=2 leaf node.
        node=3 leaf node.
    node=4 test node: go to node 5 if X[:, 0] <= -2.90554761887 else to node 6.
        node=5 leaf node.
        node=6 leaf node.

Как узнать значение, которое представляет узел 2, например?

1 Ответ

0 голосов
/ 18 сентября 2018

Метод, который вы ищете: estimator.tree_.value

Давайте сделаем воспроизводимый пример, поскольку тот, на который вы ссылаетесь из документации, предназначен для классификации, а не для регрессии:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# dummy data
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

estimator = DecisionTreeRegressor(max_depth=3)
estimator.fit(X, y)

После этого, используя ваш код дословно, получаем:

The binary tree structure has 15 nodes and has the following tree structure: 
node=0 test node: go to node 1 if X[:, 0] <= 3.13275051117 else to node 8. 
       node=1 test node: go to node 2 if X[:, 0] <= 0.513901114464 else to node 5. 
              node=2 test node: go to node 3 if X[:, 0] <= 0.0460066311061 else to node 4. 
                     node=3 leaf node. 
                     node=4 leaf node. 
              node=5 test node: go to node 6 if X[:, 0] <= 2.02933192253 else to node 7. 
                     node=6 leaf node. 
                     node=7 leaf node. 
       node=8 test node: go to node 9 if X[:, 0] <= 3.85022854805 else to node 12. 
              node=9 test node: go to node 10 if X[:, 0] <= 3.42930102348 else to node 11. 
                     node=10 leaf node. 
                     node=11 leaf node. 
              node=12 test node: go to node 13 if X[:, 0] <= 4.68025827408 else to node 14. 
                     node=13 leaf node. 
                     node=14 leaf node.

Теперь estimator.tree_.value содержит значения для всех узлов дерева (здесь 15):

len(estimator.tree_.value)
# 15

и для получения, например, значения для узла # 3, мы спрашиваем

estimator.tree_.value[3]
# array([[-1.1493464]])

Подробное объяснение содержания value (включая нетерминальные узлы) см. В моих ответах в

  1. интерпретация выходных данных Graphviz для регрессии дерева решений (для регрессии) и

  2. Что делает Scikit-learn DecisionTreeClassifier.tree_.value? (для классификации).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...