Я пытаюсь извлечь правила принятия решений для прогнозирования терминальных узлов и распечатать код, который будет использовать массивы pandas numpy для прогнозирования номеров терминальных узлов.Я нашел решение, которое может использовать правила ( Как извлечь правила принятия решений из дерева решений scikit-learn? ), но я не уверен, как его расширить, чтобы получить то, что мне нужно.Ссылка на решение имеет много ответов.Вот тот, на которого я ссылаюсь, и описание вопроса.
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
df
# create decision tree
dt = DecisionTreeClassifier(random_state=0, max_depth=5, min_samples_leaf=1)
dt.fit(df.loc[:,('col1','col2')], df.dv)
#This function first starts with the nodes (identified by -1 in the child arrays) and then recursively finds the parents.
#I call this a node's 'lineage'. Along the way, I grab the values I need to create if/then/else SAS logic:
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
for child in idx:
for node in recurse(left, right, child):
print (node)
get_lineage(dt, df.columns)
когда вы запустите код, он предоставит следующее:
(0, 'l', 3.5, 'col2')
1
(0, 'r', 3.5, 'col2')
(2, 'l', 1.5, 'col1')
3
(0, 'r', 3.5, 'col2')
(2, 'r', 1.5, 'col1')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 3.5, 'col2')
(2, 'r', 1.5, 'col1')
(4, 'r', 2.5, 'col1')
6
Как мне развернуть его, чтобы напечатать что-то вроде этого:
df['Terminal_Node_Num']=np.where(df.loc[:,'col2']<=3.5,1,0)
df['Terminal_Node_Num']=np.where(((df.loc[:,'col2']>3.5) & (df.loc[:,'col1']
<=1.5)), 3, df['Terminal_Node_Num'])
df['Terminal_Node_Num']=np.where(((df.loc[:,'col2']>3.5) &
(df.loc[:,'col1']>1.5) & (df.loc[:,'col1']<=2.5)), 5,
df['Terminal_Node_Num'])
df['Terminal_Node_Num']=np.where(((df.loc[:,'col2']>3.5)`enter code here`(df.loc[:,'col1']>1.5) & (df.loc[:,'col1']>2.5)), 6, df['Terminal_Node_Num'])