Вы можете использовать этот пример, чтобы построить то, что вам нужно.
Попробуйте этот пример:
from sklearn.tree import plot_tree
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
import numpy as np
# Example: fit a shallow decision tree on a synthetic two-feature, two-class
# dataset, draw its decision regions with the training points on top, and
# render the fitted tree itself in a second figure.
X, y = make_classification(
    n_samples=1000,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=4,
)
labels = ['type_A', 'type_B']
clf = DecisionTreeClassifier(max_depth=3).fit(X, y)

# Parameters
n_classes = 2
plot_colors = "ryb"  # one colour character per class; zip() stops at n_classes
plot_step = 0.02     # grid resolution used to rasterise the decision regions

# Plot the decision boundary
plt.figure()
# Pad the data range by 1 on every side so no point sits on the plot border.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, plot_step),
    np.arange(y_min, y_max, plot_step),
)

# Predict a class for every grid node, then fold the flat predictions back
# into the grid shape expected by contourf.
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)
plt.xlabel('feature_1')
plt.ylabel('feature_2')

# Plot the training points
for i, color in zip(range(n_classes), plot_colors):
    idx = y == i  # boolean mask selecting the samples of class i
    # Each series gets its own label (not the whole list). No cmap is passed,
    # because a colormap is ignored when `c` is a single colour.
    plt.scatter(X[idx, 0], X[idx, 1], c=color, label=labels[i],
                edgecolor='black', s=15)
plt.legend()
# Adjust the layout only after the axes, labels and legend exist; right after
# plt.figure() there is nothing to lay out yet.
plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)

# Draw the fitted tree in its own, wider figure.
f, ax = plt.subplots(figsize=(15, 7))
plot_tree(clf, filled=True, feature_names=['feature_1', 'feature_2'],
          ax=ax, fontsize=6,
          class_names=labels)
plt.show()
Обновление:
Для задачи регрессии:
from sklearn.tree import plot_tree
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor
from matplotlib import pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
# Example: fit a decision-tree regressor on a synthetic two-feature dataset
# and show, side by side, the fitted tree (left) and the predicted surface
# with the training points coloured by their target value (right).
X, y = make_regression(
    n_samples=1000,
    n_features=2,
    n_informative=2,
    random_state=0,
)
reg = DecisionTreeRegressor(max_depth=4).fit(X, y)

# Parameters
plot_step = 0.02  # grid resolution used to rasterise the prediction surface

f, axes = plt.subplots(ncols=2, figsize=(30, 7))
# Name the two panels so every drawing call targets an explicit axes instead
# of relying on pyplot's implicit "current axes".
tree_ax, surface_ax = axes[0], axes[1]

# Plot the prediction surface
# Pad the data range by 1 on every side so no point sits on the plot border.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, plot_step),
    np.arange(y_min, y_max, plot_step),
)

# Predict a value for every grid node, then fold the flat predictions back
# into the grid shape expected by contourf.
Z = reg.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
cs = surface_ax.contourf(xx, yy, Z, cmap=plt.cm.Blues)
surface_ax.set_xlabel('feature_1')
surface_ax.set_ylabel('feature_2')

# Overlay the training points, coloured by their true target value.
surface_ax.scatter(X[:, 0], X[:, 1], c=y,
                   cmap='Oranges', edgecolor='black', s=15)

# class_names applies to classifiers only, so it is not passed for a regressor.
plot_tree(reg, filled=True, feature_names=['feature_1', 'feature_2'],
          ax=tree_ax, fontsize=3)

# Adjust the layout only after both panels have their content and labels.
f.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
plt.show()