Дерево рекурсии на python - PullRequest
0 голосов
/ 25 февраля 2020

У меня есть следующий алгоритм дерева, который печатает условия для каждого листа:

def _grow_tree(self, X, y, depth=0):
    # Identify best split
    idx, thr = self._best_split(X, y)

    # Indentation for tree description
    indent = "    " * depth

    indices_left = X.iloc[:, idx] < thr
    X_left = X[indices_left]
    y_left = y_train[X_left.reset_index().loc[:,'id'].values]

    X_right = X[~indices_left]
    y_right = y_train[X_right.reset_index().loc[:,'id'].values]

    self.tree_describe.append(indent +"if x['"+ X.columns[idx] + "'] <= " +\
                             str(thr) + ':')
    # Grow on left side of the tree  
    node.left = self._grow_tree(X_left, y_left, depth + 1)

    self.tree_describe.append(indent +"else: #if x['"+ X.columns[idx] + "'] > " +\
                         str(thr) + ':')
    # Grow on right side of the tree
    node.right = self._grow_tree(X_right, y_right, depth + 1)

    return node

Это дает следующий отпечаток для конкретного случая:

["if x['VAR1'] <= 0.5:",
 "    if x['VAR2'] <= 0.5:",
 "    else: #if x['VAR2'] > 0.5:",
 "else: #if x['VAR1'] > 0.5:",
 "    if x['VAR3'] <= 0.5:",
 "    else: #if x['VAR3'] > 0.5:"]

Как я могу получить следующий вывод?:

["if x['VAR1'] <= 0.5:",
 "    if x['VAR1'] <= 0.5&x['VAR2'] <= 0.5",
 "    else: #if x['VAR1'] <= 0.5&x['VAR2'] > 0.5:",
 "else: #if x['VAR1'] > 0.5:",
 "    if x['VAR1'] > 0.5&x['VAR3'] <= 0.5:",
 "    else: #if x['VAR1'] > 0.5&x['VAR3'] > 0.5:"]

1 Ответ

1 голос
/ 25 февраля 2020

Вы можете ввести в вашу функцию новый аргумент, в котором будет строка с условиями более высокого уровня, которые необходимо добавить к каждому более глубокому условию:

Я бы также предложил использовать .format() для вашего строения:

def _grow_tree(self, X, y, depth=0, descr=""):

    idx, thr = self._best_split(X, y)

    indent = "    " * depth

    cond = "x['{}'] <= {}{}".format(X.columns[idx], thr, descr)
    self.tree_describe.append("{}if {}:".format(indent, cond))

    node.left = self._grow_tree(X_left, y_left, depth + 1, " & " + cond)

    cond = "x['{}'] > {}{}".format(X.columns[idx], thr, descr)
    self.tree_describe.append("{}else: #if {}:".format(indent, cond))

    node.right = self._grow_tree(X_right, y_right, depth + 1, " & " + cond)

    return node
...