Как извлечь правила принятия решений (функции разделены) из модели xgboost в python3? - PullRequest
0 голосов
/ 04 мая 2018

Мне нужно извлечь правила принятия решений из моей приспособленной модели xgboost в python. Я использую версию 0.6a2 библиотеки xgboost, а моя версия на Python - 3.5.2.

Моя конечная цель - использовать эти разбиения для переменных бина (в соответствии с разбиениями).

Я не обнаружил ни одного свойства модели для этой версии, которое может дать мне сплиты.

plot_tree дает мне нечто подобное. Однако это визуализация дерева.

Мне нужно что-то вроде https://stackoverflow.com/a/39772170/4559070 для модели xgboost

Ответы [ 2 ]

0 голосов
/ 14 мая 2018

Это возможно, но не легко. Я бы порекомендовал вам использовать GradientBoostingClassifier из scikit-learn, что аналогично xgboost, но имеет встроенный доступ к построенным деревьям.

Однако, используя xgboost, можно получить текстовое представление модели и затем проанализировать ее:

from sklearn.datasets import load_iris
from xgboost import XGBClassifier
# build a very simple model
X, y = load_iris(return_X_y=True)
model = XGBClassifier(max_depth=2, n_estimators=2)
model.fit(X, y);
# dump it to a text file
model.get_booster().dump_model('xgb_model.txt', with_stats=True)
# read the contents of the file
with open('xgb_model.txt', 'r') as f:
    txt_model = f.read()
print(txt_model)

Будет напечатано текстовое описание 6 деревьев (2 оценки, каждое состоит из 3 деревьев, по одному на класс), которое начинается следующим образом:

booster[0]:
0:[f2<2.45] yes=1,no=2,missing=1,gain=72.2968,cover=66.6667
    1:leaf=0.143541,cover=22.2222
    2:leaf=-0.0733496,cover=44.4444
booster[1]:
0:[f2<2.45] yes=1,no=2,missing=1,gain=18.0742,cover=66.6667
    1:leaf=-0.0717703,cover=22.2222
    2:[f3<1.75] yes=3,no=4,missing=3,gain=41.9078,cover=44.4444
        3:leaf=0.124,cover=24
        4:leaf=-0.0668394,cover=20.4444
...

Теперь вы можете, например, извлечь все разбиения из этого описания:

import re
# trying to extract all patterns like "[f2<2.45]"
splits = re.findall('\[f([0-9]+)<([0-9]+.[0-9]+)\]', txt_model)
splits

Будет выведен список кортежей (feature_id, split_value), например

[('2', '2.45'),
 ('2', '2.45'),
 ('3', '1.75'),
 ('3', '1.65'),
 ('2', '4.95'),
 ('2', '2.45'),
 ('2', '2.45'),
 ('3', '1.75'),
 ('3', '1.65'),
 ('2', '4.95')]

Вы можете в дальнейшем обработать этот список по своему желанию.

0 голосов
/ 12 мая 2018

Вам нужно знать имя вашего дерева, и после этого вы можете вставить его в свой код.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...