Обход дерева решений sklearn - PullRequest
1 голос
/ 14 апреля 2020

Как мне сделать первый обход в обход дерева решений sklearn?

В своем коде я пробовал библиотеку sklearn.tree_ и использовал различные функции, такие как tree_.feature и tree_.threshold, чтобы понять структура дерева. Но эти функции выполняют обход дерева dfs, если я хочу сделать bfs, как мне это сделать?

Предположим,

clf1 = DecisionTreeClassifier( max_depth = 2 )
clf1 = clf1.fit(x_train, y_train)

, это мой классификатор, и полученное дерево решений

enter image description here

Затем я прошел по дереву, используя следующую функцию

def encoding(clf, features):
l1 = list()
l2 = list()

for i in range(len(clf.tree_.feature)):
    if(clf.tree_.feature[i]>=0):
        l1.append( features[clf.tree_.feature[i]])
        l2.append(clf.tree_.threshold[i])
    else:
        l1.append(None)
        print(np.max(clf.tree_.value))
        l2.append(np.argmax(clf.tree_.value[i]))

l = [l1 , l2]

return np.array(l)

, и был получен результат

array([['address', 'age', None, None, 'age', None, None],
       [0.5, 17.5, 2, 1, 15.5, 1, 1]], dtype=object)

где 1-й массив - это особенность узла, или, если он является листовым, тогда он помечается как none, а 2-й массив - это порог для узла-объекта, а для узла класса - это класс, но это обход dfs дерева, я хочу сделать обход bfs, что я должен делать?

Поскольку я новичок в переполнении стека, пожалуйста, предложите, как улучшить описание вопроса и какую дополнительную информацию я должен добавить, если таковая имеется, чтобы объяснить мою проблему дальше.

X_train (sample) X_train

y_train (образец) y_train

1 Ответ

0 голосов
/ 14 апреля 2020

Это должно сделать это:

from collections import deque

tree = clf.tree_

stack = deque()
stack.append(0)  # push tree root to stack

while stack:
    current_node = stack.popleft()

    # do whatever you want with current node
    # ...

    left_child = tree.children_left[current_node]
    if left_child >= 0:
        stack.append(left_child)

    right_child = tree.children_right[current_node]
    if right_child >= 0:
        stack.append(right_child)

Это использует deque, чтобы сохранить стек узлов для последующей обработки. Поскольку мы удаляем элементы слева и добавляем их справа, это должно представлять обход в ширину.


Для реального использования я предлагаю вам превратить это в генератор:

from collections import deque

def breadth_first_traversal(tree):
    stack = deque()
    stack.append(0)

    while stack:
        current_node = stack.popleft()

        yield current_node

        left_child = tree.children_left[current_node]
        if left_child >= 0:
            stack.append(left_child)

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            stack.append(right_child)

Тогда вам нужно только минимальные изменения вашей исходной функции:

def encoding(clf, features):
    l1 = list()
    l2 = list()

    for i in breadth_first_traversal(clf.tree_):
        if(clf.tree_.feature[i]>=0):
            l1.append( features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            l1.append(None)
            print(np.max(clf.tree_.value))
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1 , l2]

    return np.array(l)
...