Вы можете получить эти данные из древовидной структуры:
import sklearn
import numpy as np
import graphviz
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.datasets import make_regression
# Generate a simple dataset
X, y = make_regression(n_features=2, n_informative=2, random_state=0)
clf = DecisionTreeRegressor(random_state=0, max_depth=2)
clf.fit(X, y)
# Visualize the tree
graphviz.Source(sklearn.tree.export_graphviz(clf)).view()
>>> clf.predict(X[:5])
0 184.005667
1 53.017289
2 184.005667
3 -20.603498
4 -97.414461
Если вы позвоните clf.apply(X)
, вы получитеИдентификатор узла, к которому относится экземпляр:
array([6, 5, 6, 3, 2, 5, 5, 3, 6, ... 5, 5, 6, 3, 2, 2, 5, 2, 2], dtype=int64)
Объединение его с целевой переменной:
df = pd.DataFrame(np.vstack([y, clf.apply(X)]), index=['y','node_id']).T
y node_id
0 190.370562 6.0
1 13.339570 5.0
2 141.772669 6.0
3 -3.069627 3.0
4 -26.062465 2.0
5 54.922541 5.0
6 25.952881 5.0
...
Теперь, если вы выполните групповую операцию на node_id
, за которой следует, что вы получитете же значения, что и clf.predict(X)
>>> df.groupby('node_id').mean()
y
node_id
2.0 -97.414461
3.0 -20.603498
5.0 53.017289
6.0 184.005667
Какие value
s листьев в нашем дереве:
>>> clf.tree_.value[6]
array([[184.00566679]])
Чтобы получить идентификаторы узла для нового набора данных, вам нужно вызвать
clf.decision_path(X[:5]).toarray()
, который показывает вам массив, подобный этому
array([[1, 0, 0, 0, 1, 0, 1],
[1, 0, 0, 0, 1, 1, 0],
[1, 0, 0, 0, 1, 0, 1],
[1, 1, 0, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]], dtype=int64)
, где вам нужно получить последний ненулевой элемент (т.е. лист)
>>> pd.DataFrame(clf.decision_path(X[:5]).toarray()).apply(lambda x:x.nonzero()[0].max(), axis=1)
0 6
1 5
2 6
3 3
4 2
dtype: int64
Так что если бы вместо предсказания среднего значения вы хотели предсказать медиану, вы бы сделали
>>> pd.DataFrame(clf.decision_path(X[:5]).toarray()).apply(lambda x: x.nonzero()[0].max(
), axis=1).to_frame(name='node_id').join(df.groupby('node_id').median(), on='node_id')['y']
0 181.381106
1 54.053170
2 181.381106
3 -28.591188
4 -93.891889