как правила извлечения правил случайного леса в питоне - PullRequest
0 голосов
/ 30 мая 2018

У меня есть один вопрос, хотя.Я слышал от кого-то, что в R вы можете использовать дополнительные пакеты для извлечения правил принятия решений, реализованных в RF, я пытаюсь гуглить то же самое в python, но без удачи, если есть какая-либо помощь в том, как этого добиться.заранее спасибо!

1 Ответ

0 голосов
/ 23 июня 2018

Предполагая, что вы используете sklearn RandomForestClassifier, вы можете найти отдельные деревья решений как .estimators_.Каждое дерево хранит узлы принятия решений в виде нескольких массивов NumPy в tree_.

. Вот пример кода, который просто печатает каждый узел в порядке массива.В типичном приложении вместо этого следуют дети.

import numpy
from sklearn.model_selection import train_test_split
from sklearn import metrics, datasets, ensemble

def print_decision_rules(rf):

    for tree_idx, est in enumerate(rf.estimators_):
        tree = est.tree_
        assert tree.value.shape[1] == 1 # no support for multi-output

        print('TREE: {}'.format(tree_idx))

        iterator = enumerate(zip(tree.children_left, tree.children_right, tree.feature, tree.threshold, tree.value))
        for node_idx, data in iterator:
            left, right, feature, th, value = data

            # left: index of left child (if any)
            # right: index of right child (if any)
            # feature: index of the feature to check
            # th: the threshold to compare against
            # value: values associated with classes            

            # for classifier, value is 0 except the index of the class to return
            class_idx = numpy.argmax(value[0])

            if left == -1 and right == -1:
                print('{} LEAF: return class={}'.format(node_idx, class_idx))
            else:
                print('{} NODE: if feature[{}] < {} then next={} else next={}'.format(node_idx, feature, th, left, right))    


digits = datasets.load_digits()
Xtrain, Xtest, ytrain, ytest = train_test_split(digits.data, digits.target)
estimator = ensemble.RandomForestClassifier(n_estimators=3, max_depth=2)
estimator.fit(Xtrain, ytrain)

print_decision_rules(estimator)

Пример outout:

TREE: 0
0 NODE: if feature[33] < 2.5 then next=1 else next=4
1 NODE: if feature[38] < 0.5 then next=2 else next=3
2 LEAF: return class=2
3 LEAF: return class=9
4 NODE: if feature[50] < 8.5 then next=5 else next=6
5 LEAF: return class=4
6 LEAF: return class=0
...

Мы используем нечто подобное в emtrees для компиляции кода Random Forest в C.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...