partykit: Измените блок-графики терминальных узлов на гистограммы, которые показывают среднее и стандартное отклонение - PullRequest
0 голосов
/ 20 декабря 2018

Я создал дерево регрессии в R. Вот код:

tree <- rpart(y~., method="anova", minsplit=20, minbucket=20, maxdepth=3, data=foo)

plot(as.party(tree), terminal_panel = node_boxplot)

Вместо коробочного графика, который показывает медиану и IQR, я хотел бы, чтобы мои терминальные узлы включали гистограмму, содин столбец для отображения средних и стандартных ошибок отклонения.Я протестировал все опции Terminal_panel, но ни один из них не сделал этого.Есть предложения?

1 Ответ

0 голосов
/ 21 декабря 2018

Насколько мне известно, такой функции панели не существовало, поэтому я ее написал.См. node_dynamite() ниже.С этим вы можете сделать:

library("rpart")
library("partykit")
p <- as.party(rpart(dist ~ speed, data = cars))
plot(p, terminal_panel = node_dynamite)

node_dynamite

node_dynamite <- function(obj, factor = 1,
                          col = "black",
                          fill = "lightgray",
                          bg = "white",
                          width = 0.5,
                          yscale = NULL,
                          ylines = 3,
                          cex = 0.5,
                          id = TRUE,
                          mainlab = NULL, 
                          gp = gpar())
{
    ## observed data/weights and tree fit
    y <- obj$fitted[["(response)"]]
    stopifnot(is.numeric(y))
    g <- obj$fitted[["(fitted)"]]
    w <- obj$fitted[["(weights)"]]
    if(is.null(w)) w <- rep(1, length(y))

    ## (weighted) means and standard deviations by node
    n <- tapply(w, g, sum)
    m <- tapply(y * w, g, sum)/n
    s <- sqrt(tapply((y - m[factor(g)])^2 * w, g, sum)/(n - 1))

    if (is.null(yscale)) 
        yscale <- c(min(c(0, (m - factor * s) * 1.1)), max(c(0, (m + factor * s) * 1.1)))

    ### panel function for boxplots in nodes
    rval <- function(node) {

        ## extract data
        nid <- id_node(node)
        mid <- m[as.character(nid)]
        sid <- s[as.character(nid)]
        wid <- n[as.character(nid)]

        top_vp <- viewport(layout = grid.layout(nrow = 2, ncol = 3,
                           widths = unit(c(ylines, 1, 1), 
                                         c("lines", "null", "lines")),  
                           heights = unit(c(1, 1), c("lines", "null"))),
                           width = unit(1, "npc"), 
                           height = unit(1, "npc") - unit(2, "lines"),
               name = paste("node_dynamite", nid, sep = ""),
               gp = gp)

        pushViewport(top_vp)
        grid.rect(gp = gpar(fill = bg, col = 0))

        ## main title
        top <- viewport(layout.pos.col=2, layout.pos.row=1)
        pushViewport(top)
        if (is.null(mainlab)) { 
      mainlab <- if(id) {
        function(id, nobs) sprintf("Node %s (n = %s)", id, nobs)
      } else {
        function(id, nobs) sprintf("n = %s", nobs)
      }
        }
    if (is.function(mainlab)) {
          mainlab <- mainlab(names(obj)[nid], wid)
    }
        grid.text(mainlab)
        popViewport()

        plot <- viewport(layout.pos.col = 2, layout.pos.row = 2,
                         xscale = c(0, 1), yscale = yscale,
             name = paste0("node_dynamite", nid, "plot"),
             clip = FALSE)

        pushViewport(plot)

        grid.yaxis()
        grid.rect(gp = gpar(fill = "transparent"))
    grid.clip()

    xl <- 0.5 - width/8
    xr <- 0.5 + width/8

        ## box & whiskers
        grid.rect(unit(0.5, "npc"), unit(0, "native"), 
                  width = unit(width, "npc"), height = unit(mid, "native"),
                  just = c("center", "bottom"), 
                  gp = gpar(col = col, fill = fill))
        grid.lines(unit(0.5, "npc"), 
                   unit(mid + c(-1, 1) * factor * sid, "native"), gp = gpar(col = col))
        grid.lines(unit(c(xl, xr), "npc"), unit(mid - factor * sid, "native"), 
                   gp = gpar(col = col))
        grid.lines(unit(c(xl, xr), "npc"), unit(mid + factor * sid, "native"), 
                   gp = gpar(col = col))

        upViewport(2)
    }

    return(rval)
}
class(node_dynamite) <- "grapcon_generator"
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...