Получить терминальные узлы данных в дереве библиотеки - PullRequest
1 голос
/ 31 марта 2020

Я пытаюсь построить или с некоторыми тестовыми данными. Моя цель - узнать, сколько терминальных узлов / листьев у моего дерева и в каких терминальных узлах появляются новые данные.

Я использую библиотеку , потому что у нее есть возможность получить узел, в который попадает каждая точка данных, используя predict(tree.model, data=df, type="where")

Я создал несколько примеров данных и попробовал это. Но похоже, что predict не только выводит терминальные узлы. При запуске моего кода predict(...) имеет факторы 3 5 6 8 9. Но дерево выглядит как

 1) root 700 969.900 1 ( 0.487143 0.512857 )  
   2) B < 0.339751 346 104.300 0 ( 0.965318 0.034682 )  
     4) A < 0.747861 331  13.600 0 ( 0.996979 0.003021 ) *
     5) A > 0.747861 15  17.400 1 ( 0.266667 0.733333 )  
      10) B < 0.139725 5   5.004 0 ( 0.800000 0.200000 ) *
      11) B > 0.139725 10   0.000 1 ( 0.000000 1.000000 ) *
   3) B > 0.339751 354  68.790 1 ( 0.019774 0.980226 )  
     6) A < 0.157866 8   6.028 0 ( 0.875000 0.125000 ) *
     7) A > 0.157866 346   0.000 1 ( 0.000000 1.000000 ) *

(«*» обозначает конечные узлы).

Есть ли возможность получать только терминальные узлы? Желательно в древовидной библиотеке.

Вот мой полный пример кода, основная часть только создает примеры данных.

library(ggplot2)
library(hrbrthemes)

#generating some data to test######################################
    set.seed(42)
    #category A
    x1s = rchisq(500, 5, ncp = 0)
    y1s = 1/x1s +0.1*rchisq(500, 8, ncp = 0)
    x1s = (x1s-min(x1s))/max(x1s)
    y1s = (y1s-min(y1s))/max(y1s)
    #category B
    x2s = 15-rchisq(500, 5, ncp = 0)
    y2s = 5-(2.5 -1/400*(x2s-15)^2 +0.1*rchisq(500, 8, ncp = 0))
    x2s = (x2s-min(x2s))/max(x2s)
    y2s = (y2s-min(y2s))/max(y2s)

    xs = c(x1s, x2s)
    ys = c(y1s, y2s)
    type = c(0*(1:500), 0*(1:500)+1)
    df = data.frame(type, xs, ys)
    names(df) = c("category","A","B")
    df$category = factor(df$category)



#plot the generated data##########################################
ggplot(df, aes(x=A, y=B, color=category)) + geom_point(shape=1)

#seperate in training and test data
alpha = 0.7
inTrain = sample(1:nrow(df), alpha*nrow(df))
train.set = df[inTrain,]
test.set = df[-inTrain, ]

####################################################################
#use tree to predict category
library(tree)
tree.model = tree(category ~ A + B, data = train.set)
factor(predict(tree.model, data = test.set, type="where"))
tree.model
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...