Можно ли изменить метку узла в дереве решений sklearn? - PullRequest
0 голосов
/ 18 июня 2019

Я строю модель дерева решений, используя scikit-learn, и после того, как хочу переписать некоторые листы.По сути, я хочу изменить метку конкретных узлов листьев.

Я перебираю листья и на основе tree.DecisionTreeClassifier.tree_ я могу получить tree_.value для вычисления метки узла.Я получил это от здесь .У меня вопрос, могу ли я и как заставить изменение метки для узла дерева решений?

Пока я пытался вручную изменить значения в tree_.value

from sklearn import tree
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np

df = pd.read_csv("voting.csv", header=0)
y = pd.DataFrame(df.target)
feature_names = []
for col in df.columns:
    if col != 'target':
        feature_names.append(col)

y = df.target
df = df.drop("target", 1)

thr = 0.9
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2)
clf = tree.DecisionTreeClassifier(min_samples_leaf=3)
clf.fit(X_train, y_train)

node_count = clf.tree_.node_count
class_label = 0
for index in range(node_count):
    # check if it is a leaf
    if clf.tree_.children_right[index] == -1 and clf.tree_.children_left[index] == -1:
    # number of samples in the leaf (correctly classified and misclassified)
    print("Values: ", clf.tree_.value[index])
    # Finding node label
    node_label = clf.classes_[np.argmax(clf.tree_.value[index])]
    values = clf.tree_.value[index]
    correct_samples = values[0][node_label]
    misclassified_samples = np.sum(clf.tree_.value[index]) - correct_samples
    # Change the label if number of misclassified samples is more than 0
    if misclassified_samples > 0 and node_label != class_label:
        clf.tree_.value[index][class_label] = clf.tree_.value[index][class_label] + correct_samples
        print("New values: ", clf.tree_.value[index])

Но это приводит к изменению обоих значений даже для правильно классифицированных.И тогда метка узла остается прежней.Например, до операции: Values: [[1. 2.]], а после операции: New values: [[3. 4.]]

Спасибо!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...