Дерево решений, использующее rpart для создания диаграммы Санки - PullRequest
0 голосов
/ 06 сентября 2018

Я могу создать дерево с помощью Rpart, используя набор данных Kyphosis, который является частью базы R:

fit <- rpart(Kyphosis ~ Age + Number + Start,
         method="class", data=kyphosis)
printcp(fit)
plot(fit, uniform=TRUE,main="Classification Tree for Kyphosis")
text(fit, use.n=TRUE, all=TRUE, cex=.8)

Вот как выглядит дерево: enter image description here

Теперь, чтобы лучше представить дерево, я хочу использовать диаграмму Санки, используя график. Чтобы создать диаграмму Санки на графике, нужно сделать следующее:

library(plotly)
nodes=c("Start>=8.5","Start>-14.5","absent",
                   "Age<55","absent","Age>=111","absent","present","present")
p <- plot_ly(
  type = "sankey",
  orientation = "h",      
  node = list(
    label = nodes,
    pad = 10,
    thickness = 20,
    line = list(
      color = "black",
      width = 0.5
    )
  ),

  link = list(
    source = c(0,1,1,3,3,5,5,0),
    target = c(1,2,3,4,5,6,7,8),
    value =  c(1,1,1,1,1,1,1,1)
  )
) %>% 
  layout(
    title = "Desicion Tree",
    font = list(
      size = 10
    )
  )
p

Это создает диаграмму Санки, соответствующую дереву (жестко запрограммированному). Три необходимых вектора: «источник», «цель», «значение» выглядят следующим образом:

Схема Санки с жестким кодом:

enter image description here

Моя проблема в использовании объекта rpart 'fit'. Мне кажется, я не могу легко получить вектор для получения требуемых векторов 'source', 'target' и 'value' для графика.

fit $ frame и fit $ split содержит некоторую информацию, но их сложно объединить или использовать вместе. Использование функции печати для объекта подгонки дает необходимую информацию, но я не хочу редактировать текст, чтобы получить ее.

print(fit)

Выход:

1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

Так есть ли простой способ использовать объект rpart для получения этих 3 векторов для построения графика Санки? Этот график будет использоваться в веб-приложении, поэтому его необходимо использовать специально, поскольку у нас уже есть javascript, который ему соответствует, и его легко можно использовать повторно для применения к различным наборам данных.

Ответы [ 2 ]

0 голосов
/ 14 сентября 2018

У меня есть временное решение на данный момент. Мне просто не нравится загрузка в дополнительную библиотеку. Но вот оно: Подгонка модели для набора данных Iris:

fit <- rpart(Species~Sepal.Length +Sepal.Width   ,
         method="class", data=iris)

printcp(fit)
plot(fit, uniform=TRUE, 
     main="Classification Tree for IRIS")
text(fit, use.n=TRUE, all=TRUE, cex=.8)

Способ получения имен узлов был следующим:

treeFrame=fit$frame
nodes=sapply(row.names(treeFrame),function(x) unlist(rpart::path.rpart(fit,x))
        [length(unlist(rpart::path.rpart(fit,x)))])

Но в решении @BigDataScientist есть лучший способ:

treeFrame=fit$frame
isLeave <- treeFrame$var == "<leaf>"
nodes <- rep(NA, length(isLeave))
ylevel <- attr(fit, "ylevels")
nodes[isLeave] <- ylevel[treeFrame$yval][isLeave]
nodes[!isLeave] <- labels(fit)[-1][!isLeave[-length(isLeave)]]

Теперь получить источник и цель все еще сложно, но мне помог пакет rpart.utils:

library('rpart.utils')
treeFrame=fit$frame
treeRules=rpart.utils::rpart.rules(fit)

targetPaths=sapply(as.numeric(row.names(treeFrame)),function(x)  
                      strsplit(unlist(treeRules[x]),split=","))

lastStop=  sapply(1:length(targetPaths),function(x) targetPaths[[x]] 
                      [length(targetPaths[[x]])])

oneBefore=  sapply(1:length(targetPaths),function(x) targetPaths[[x]] 
                      [length(targetPaths[[x]])-1])


target=c()
source=c()
values=treeFrame$n
for(i in 2:length(oneBefore))
{
  tmpNode=oneBefore[[i]]
  q=which(lastStop==tmpNode)

  q=ifelse(length(q)==0,1,q)
  source=c(source,q)
  target=c(target,i)

}
source=source-1
target=target-1

Так что мне не нравится использовать дополнительную библиотеку, но, похоже, это работает для различных наборов данных. И лучше использовать способ @BigDataScientist для получения узлов. Но я все еще буду искать лучшие решения. @ BigDataScientist Я думаю, что ваше решение будет работать лучше, возможно, что-то маленькое нужно изменить. Но я пока не очень хорошо понимаю часть «повторений» в вашем коде.

А код для сюжета в конце концов:

 p <- plot_ly(
 type = "sankey",
 orientation = "v",

 node = list(
     label = nodes,
     pad = 15,
     thickness = 20,
     line = list(
     color = "black",
     width = 0.5
     )
 ),

 link = list(
     source = source,
     target = target,
     value=values[-1]

 )
 ) %>% 
 layout(
     title = "Basic Sankey Diagram",
     font = list(
     size = 10
     )
 )
 p
0 голосов
/ 13 сентября 2018

Вот моя попытка:

Из того, что я вижу, задача состоит в том, чтобы генерировать nodes и source переменные.

Пример данных:

fit <- rpart(Kyphosis ~ Age + Number + Start,
             method="class", data=kyphosis)

Генерация nodes:

frame <- fit$frame
isLeave <- frame$var == "<leaf>"
nodes <- rep(NA, length(isLeave))
ylevel <- attr(fit, "ylevels")
nodes[isLeave] <- ylevel[frame$yval][isLeave]
nodes[!isLeave] <- labels(fit)[-1][!isLeave[-length(isLeave)]]

Генерация source:

node <- as.numeric(row.names(frame))
depth <- rpart:::tree.depth(node)
source <- depth[-1] - 1

reps <- rle(source)
tobeAdded <- reps$values[sapply(reps$values, function(val) sum(val >= which(reps$lengths > 1))) > 0]
update <- source %in% tobeAdded
source[update] <- source[update] + sapply(tobeAdded, function(tobeAdd) rep(sum(which(reps$lengths > 1) <= tobeAdd), 2))

Протестировано с:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start,
             method="class", data=kyphosis)
fit2 <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
              parms = list(prior = c(.65,.35), split = "information"))

Как туда добраться:

См .: getS3method("print", "rpart")

...