Я пытаюсь реализовать простое дерево решений для набора данных. Я использую следующие импорты:
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn import tree
import pydotplus
и генерирую дерево, используя:
X_train = data.drop('var', axis=1)
y_train = data['var']
dt = tree.DecisionTreeClassifier()
dt.fit(X_train, y_train)
Y_pred_4 = dt.predict(data_test)
acc_dt = dt.score(X_train, y_train) * 100
acc_dt
Теперь я пытаюсь построить дерево решений, созданное с помощью:
fig = dt.fit(X_train,y_train)
tree.plot_tree(fig)
plt.show()
Но я получаю следующую ошибку при построении
AttributeError: модуль 'sklearn.tree' не имеет атрибута 'plot_tree'
Я не уверен, что не так в коде. Может кто посоветует?
Спасибо