Пользовательский Rpart, обращающийся к подмножеству данных на каждом узле во время рекурсии ... не после построения дерева - PullRequest
0 голосов
/ 21 апреля 2020

Я изучаю написание пользовательской функции разделения в Rpart, где мне нужен доступ ко всему поднабору данных, оцениваемых в данном узле внутри функции разделения (в соответствии с пользовательскими рекомендациями по rpart, стр. 4-6 https://cran.r-project.org/web/packages/rpart/vignettes/usercode.pdf ). Идея состоит в том, чтобы оценить комбинации разбиений по другим ковариатам (скажем, x2 и x3) в дополнение к тому, что ковариата (скажем, x1) вызывается для любого конкретного узла. Из того, что я могу сказать, функция расщепления принимает аргументы (среди прочих) y и x, где y - это подмножество ответов в конкретном узле, а x - соответствующий ковариат, оцениваемый в этом узле, но я не вижу способа доступа остальная часть набора данных при этом вызове.

Прикрепленный код основан на поддельном наборе данных, который я составил, и это не попытка найти решение, а больше я исследую, какие объекты доступны для меня в рамках каждый рекурсивный шаг, так что я могу изменить пользовательскую функцию разделения, чтобы создать то, что я хочу, не красиво, но это делает работу. Я скопировал пользовательский пример из вышеупомянутой ссылки и изменил stemp с помощью print (), чтобы посмотреть, что я смогу сделать. Я попытался напечатать имена строк векторов y и x, надеясь, что индексы уровня строки наблюдения из исходного кадра данных будут проходить, но без кубиков. В моем примере поддельные данные составляют 400 строк, поэтому есть ли способ получить доступ ко всем данным, связанным с каждым узлом (скажем, например, 23 наблюдения, которые сделали это с узлом где-то в дереве) внутри stemp ()?

# ---------- ANOVA EXAMPLE ------------
# *************************************
n <- 400
set.seed(36159)


# feature data
feat.c <- rbind(
  cbind(x1.c=rep(1, n/4), x2.c=rep(0, n/4), x3.c=runif(n/4, 0, 1))
  ,cbind(x1.c=rep(1, n/4), x2.c=rep(1, n/4), x3.c=runif(n/4, 0, 1))
  ,cbind(x1.c=rep(0, n/4), x2.c=rep(0, n/4), x3.c=runif(n/4, 0, 1))
  ,cbind(x1.c=rep(0, n/4), x2.c=rep(1, n/4), x3.c=runif(n/4, 0, 1))
)


# response variable
y.c <- c()
for(i in 1:n){
  y.c[i] <- rnorm(1, 1000*(feat.c[i, 'x1.c']==feat.c[i, 'x2.c']) + feat.c[i, 'x3.c'], sd=1)
}


# combine feater and response into data frame
dat <- data.frame(cbind(feat.c, y.c))


# initilization function required by custom rpart call
itemp <- function(y, offset, parms, wt) {
  if (is.matrix(y) && ncol(y) > 1)
    stop("Matrix response not allowed")
  if (!missing(parms) && length(parms) > 0)
    warning("parameter argument ignored")
  if (length(offset)) y <- y - offset
  sfun <- function(yval, dev, wt, ylevel, digits ) {
    paste(" mean=", format(signif(yval, digits)),
          ", MSE=" , format(signif(dev/wt, digits)),
          sep = '')
  }
  environment(sfun) <- .GlobalEnv
  list(y = c(y), parms = NULL, numresp = 1, numy = 1, summary = sfun)
}


# eval function required by custom rpart call
etemp <- function(y, wt, parms) {
  wmean <- sum(y*wt)/sum(wt)
  rss <- sum(wt*(y-wmean)^2)
  list(label = wmean, deviance = rss)
}


# splitting function required by custom rpart call
stemp <- function(y, wt, x, parms, continuous)
{
  # me exploring what I have access to that might help me extract the data for a particular node
  # print(ls())
  print(row.names(y))
  # print(row.names(x))

  # Center y
  n <- length(y)
  y <- y- sum(y*wt)/sum(wt)
  if (continuous) {
    # continuous x variable
    temp <- cumsum(y*wt)[-n]
    left.wt <- cumsum(wt)[-n]
    right.wt <- sum(wt) - left.wt
    lmean <- temp/left.wt
    rmean <- -temp/right.wt
    goodness <- (left.wt*lmean^2 + right.wt*rmean^2)/sum(wt*y^2)
    list(goodness = goodness, direction = sign(lmean))
  } else {
    # Categorical X variable
    ux <- sort(unique(x))
    wtsum <- tapply(wt, x, sum)
    ysum <- tapply(y*wt, x, sum)
    means <- ysum/wtsum
    # For anova splits, we can order the categories by their means
    # then use the same code as for a non-categorical
    ord <- order(means)
    n <- length(ord)
    temp <- cumsum(ysum[ord])[-n]
    left.wt <- cumsum(wtsum[ord])[-n]
    right.wt <- sum(wt) - left.wt
    lmean <- temp/left.wt
    rmean <- -temp/right.wt
    list(goodness= (left.wt*lmean^2 + right.wt*rmean^2)/sum(wt*y^2),
         direction = ux[ord])
  }
}


# package custom function into list and build tree
ulist <- list(eval = etemp, split = stemp, init = itemp)
mytree <- rpart(y.c ~ x1.c + x2.c +x3.c, data=dat, method = ulist, minsplit = 10)
...