rpart создает таблицу, которая указывает, принадлежит ли наблюдение узлу или нет - PullRequest
0 голосов
/ 18 января 2019

На следующем рисунке показано, что я хочу сделать:

  1. Вырастите дерево с rpart для некоторого набора данных
  2. Создать таблицу с одной строкой на наблюдение в исходном наборе данных и одним столбцом на узел в дереве плюс идентификатор. Столбцы узлов должны принимать значение 1, если наблюдение принадлежит этому узлу, и ноль в противном случае.

enter image description here

Это код, который я написал:

library(rpart)
  library(rattle)
  data <- kyphosis
  fit <- rpart(Age ~ Number + Start, data = kyphosis)
  fancyRpartPlot(fit)

  nodeNumbers <- as.numeric(rownames(fit$frame))

  paths <- path.rpart(fit, nodeNumbers)

  for(i in 1:length(nodeNumbers)){
    nodeNumber <- nodeNumbers[i]
    data[,paste0('gp', nodeNumber)] <- NA
    path <- paths[[i]]
    if(length(path) == 1) # i.e. we're at the root
      data[,paste0('gp', nodeNumber)] <- 1 else
        print('help')
  }
  data

Есть ли какой-нибудь пакет, чтобы делать то, что мне нужно? Единственный способ, которым я могу думать об этом, - это использовать магию регулярных выражений для объекта paths. Я думаю / надеюсь, что есть более простой способ сделать это.

1 Ответ

0 голосов
/ 19 января 2019

Есть ли пакет, чтобы сделать то, что мне нужно?

AFAIK, нет, но эта работа в rpart версии 4.1.13

# function to get the binary matrix OP wants given the leaf index
get_nodes <- function(object, where){
  rn <- row.names(object$frame)
  edges <- descendants(as.numeric(rn))
  o <- t(edges)[where, , drop = FALSE]
  colnames(o) <- paste0("GP", rn)
  o
}
environment(get_nodes) <- environment(rpart)

# use function 
nodes <- get_nodes(fit, fit$where)
head(nodes, 9)
#R       GP1   GP2   GP3   GP6   GP7  GP14  GP15
#R [1,] TRUE FALSE  TRUE FALSE  TRUE  TRUE FALSE
#R [2,] TRUE FALSE  TRUE FALSE  TRUE FALSE  TRUE
#R [3,] TRUE FALSE  TRUE FALSE  TRUE  TRUE FALSE
#R [4,] TRUE  TRUE FALSE FALSE FALSE FALSE FALSE
#R [5,] TRUE FALSE  TRUE FALSE  TRUE FALSE  TRUE
#R [6,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
#R [7,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
#R [8,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
#R [9,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE

# compare with
head(data, 9)
#R   Kyphosis Age Number Start
#R 1   absent  71      3     5
#R 2   absent 158      3    14
#R 3  present 128      4     5
#R 4   absent   2      5     1
#R 5   absent   1      4    15
#R 6   absent   1      2    16
#R 7   absent  61      2    17
#R 8   absent  37      3    16
#R 9   absent 113      2    16

Вот полный код, который соответствует модели, создает функцию, которая может получить конечный лист для нового набора данных, а также создает и использует вышеуказанную функцию

# do as OP
library(rpart)
library(rattle)
data <- kyphosis
fit <- rpart(Age ~ Number + Start, data = kyphosis)
fancyRpartPlot(fit)

enter image description here

# function that gives us the leaf index
get_where <- function(object, newdata, na.action = na.pass){
  if (is.null(attr(newdata, "terms"))) {
    Terms <- delete.response(object$terms)
    newdata <- model.frame(Terms, newdata, na.action = na.action, 
                           xlev = attr(object, "xlevels"))
    if (!is.null(cl <- attr(Terms, "dataClasses"))) 
      .checkMFClasses(cl, newdata, TRUE)
  }
  pred.rpart(object, rpart.matrix(newdata))
}
environment(get_where) <- environment(rpart)

# check that we get the correct value
where <- get_where(fit, data)
stopifnot(isTRUE(all.equal(
  fit$frame$yval[where], unname(predict(fit, newdata = data)))))

# function to get the binary matrix OP wants given the leaf index
get_nodes <- function(object, where){
  rn <- row.names(object$frame)
  edges <- descendants(as.numeric(rn))
  o <- t(edges)[where, , drop = FALSE]
  colnames(o) <- paste0("GP", rn)
  o
}
environment(get_nodes) <- environment(rpart)

# use function 
nodes <- get_nodes(fit, where)
head(nodes, 9)
#R       GP1   GP2   GP3   GP6   GP7  GP14  GP15
#R [1,] TRUE FALSE  TRUE FALSE  TRUE  TRUE FALSE
#R [2,] TRUE FALSE  TRUE FALSE  TRUE FALSE  TRUE
#R [3,] TRUE FALSE  TRUE FALSE  TRUE  TRUE FALSE
#R [4,] TRUE  TRUE FALSE FALSE FALSE FALSE FALSE
#R [5,] TRUE FALSE  TRUE FALSE  TRUE FALSE  TRUE
#R [6,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
#R [7,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
#R [8,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
#R [9,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE

# compare with
head(data, 9)
#R   Kyphosis Age Number Start
#R 1   absent  71      3     5
#R 2   absent 158      3    14
#R 3  present 128      4     5
#R 4   absent   2      5     1
#R 5   absent   1      4    15
#R 6   absent   1      2    16
#R 7   absent  61      2    17
#R 8   absent  37      3    16
#R 9   absent 113      2    16

Код от rpart:::predict.rpart и rpart::path.rpart. Вы можете, конечно, объединить функции get_where и get_nodes, если хотите.

...