Накапливать условия во время рекурсии по дереву классификации - PullRequest
0 голосов
/ 17 января 2020

У меня есть следующая функция, которая генерирует код из дерева классификации обучения sci-kit:

def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):

    # Remove pre-existent file
    if rm_file:
        import os
        try:
            os.remove('./tree.py')
        except OSError:
            pass

    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    file = open('tree.py', 'a')
    file.write('def ' + mx_name + '(x):'+ '\n') 
    #col_name = ''
    def recurse(node, depth):
        global col_name
        indent = "    " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]

            file.write(indent +"if x['"+ name + "'] <= " + str(threshold) + ':' + '\n')
            col_name += "'"+name + '_' + '<=' + str(threshold) +"'"

            recurse(tree_.children_left[node], depth + 1)


            file.write(indent + "else: # if x['"+ name +"'] > " + str(threshold) + '\n')
            col_name += "'"+name + '_' + '>' + str(threshold) +"'"

            recurse(tree_.children_right[node], depth + 1)


        else:
            file.write(indent + 'return '+str(col_name) + '\n')
            #print(col_name)
            col_name = ""

    recurse(0, 1)
    file.close()

При этом я получаю следующий вывод в файле tree.py для данного дерева классификации:

def mxTree(x):
    if x['V1'] <= 0.5:
        if x['V2'] <= 0.5:
            return 'V1_<=0.5''V2_<=0.5'
        else: # if x['V2'] > 0.5
            return 'V2_>0.5'
    else: # if x['V1'] > 0.5
        return 'V1_>0.5'

Хотя я могу накапливать условия на стороне IF и возвращать сложение условий, я не могу выполнить накопление, когда следуют IF и ELSE (левая / правая сторона узла дерева):

def mxTree(x):
    if x['V1'] <= 0.5:
        if x['V2'] <= 0.5:
            return 'V1_<=0.5''V2_<=0.5'
        else: # if x['V2'] > 0.5
            return 'V1_<=0.5''V2_>0.5' # 'V1<=0.5' must be added
    else: # if x['V1'] > 0.5
        return 'V1_>0.5'

Буду признателен за любые предложения.

1 Ответ

0 голосов
/ 17 января 2020

Так как левая / правая сторона каждого узла рекурсивны одновременно, я просто создал дополнительную переменную, которая сохраняет выходные данные для каждой стороны. Наконец, я объединяю переменную col_name:

col_name = ""
names_list={}
def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):

    # Remove pre-existent file
    if rm_file:
        import os
        try:
            os.remove('./tree.py')
        except OSError:
            pass

    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    file = open('tree.py', 'a')
    file.write('def ' + mx_name + '(x):'+ '\n') 

    def recurse(node, depth):
        global col_name, names_list
        indent = "    " * depth
        names_list[node] = col_name
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]

            file.write(indent +"if x['"+ name + "'] <= " + str(threshold) + ':' + '\n')
            col_name += "'"+name + '_' + '<=' + str(threshold) +"'"

            recurse(tree_.children_left[node], depth + 1)


            file.write(indent + "else: # if x['"+ name +"'] > " + str(threshold) + '\n')
            col_name += names_list[node]
            col_name += "'"+name + '_' + '>' + str(threshold) +"'"

            recurse(tree_.children_right[node], depth + 1)


        else:
            file.write(indent + 'return '+str(col_name) + '\n')
            col_name = ""

    recurse(0, 1)
    file.close()

Интересно, есть ли другие рабочие подходы.

...