Я изучаю написание пользовательской функции разделения в Rpart, где мне нужен доступ ко всему поднабору данных, оцениваемых в данном узле внутри функции разделения (в соответствии с пользовательскими рекомендациями по rpart, стр. 4-6 https://cran.r-project.org/web/packages/rpart/vignettes/usercode.pdf ). Идея состоит в том, чтобы оценить комбинации разбиений по другим ковариатам (скажем, x2 и x3) в дополнение к тому, что ковариата (скажем, x1) вызывается для любого конкретного узла. Из того, что я могу сказать, функция расщепления принимает аргументы (среди прочих) y и x, где y - это подмножество ответов в конкретном узле, а x - соответствующий ковариат, оцениваемый в этом узле, но я не вижу способа доступа остальная часть набора данных при этом вызове.
Прикрепленный код основан на поддельном наборе данных, который я составил, и это не попытка найти решение, а больше я исследую, какие объекты доступны для меня в рамках каждый рекурсивный шаг, так что я могу изменить пользовательскую функцию разделения, чтобы создать то, что я хочу, не красиво, но это делает работу. Я скопировал пользовательский пример из вышеупомянутой ссылки и изменил stemp с помощью print (), чтобы посмотреть, что я смогу сделать. Я попытался напечатать имена строк векторов y и x, надеясь, что индексы уровня строки наблюдения из исходного кадра данных будут проходить, но без кубиков. В моем примере поддельные данные составляют 400 строк, поэтому есть ли способ получить доступ ко всем данным, связанным с каждым узлом (скажем, например, 23 наблюдения, которые сделали это с узлом где-то в дереве) внутри stemp ()?
# ---------- ANOVA EXAMPLE ------------
# *************************************
n <- 400
set.seed(36159)
# feature data
feat.c <- rbind(
cbind(x1.c=rep(1, n/4), x2.c=rep(0, n/4), x3.c=runif(n/4, 0, 1))
,cbind(x1.c=rep(1, n/4), x2.c=rep(1, n/4), x3.c=runif(n/4, 0, 1))
,cbind(x1.c=rep(0, n/4), x2.c=rep(0, n/4), x3.c=runif(n/4, 0, 1))
,cbind(x1.c=rep(0, n/4), x2.c=rep(1, n/4), x3.c=runif(n/4, 0, 1))
)
# response variable
y.c <- c()
for(i in 1:n){
y.c[i] <- rnorm(1, 1000*(feat.c[i, 'x1.c']==feat.c[i, 'x2.c']) + feat.c[i, 'x3.c'], sd=1)
}
# combine feater and response into data frame
dat <- data.frame(cbind(feat.c, y.c))
# initilization function required by custom rpart call
itemp <- function(y, offset, parms, wt) {
if (is.matrix(y) && ncol(y) > 1)
stop("Matrix response not allowed")
if (!missing(parms) && length(parms) > 0)
warning("parameter argument ignored")
if (length(offset)) y <- y - offset
sfun <- function(yval, dev, wt, ylevel, digits ) {
paste(" mean=", format(signif(yval, digits)),
", MSE=" , format(signif(dev/wt, digits)),
sep = '')
}
environment(sfun) <- .GlobalEnv
list(y = c(y), parms = NULL, numresp = 1, numy = 1, summary = sfun)
}
# eval function required by custom rpart call
etemp <- function(y, wt, parms) {
wmean <- sum(y*wt)/sum(wt)
rss <- sum(wt*(y-wmean)^2)
list(label = wmean, deviance = rss)
}
# splitting function required by custom rpart call
stemp <- function(y, wt, x, parms, continuous)
{
# me exploring what I have access to that might help me extract the data for a particular node
# print(ls())
print(row.names(y))
# print(row.names(x))
# Center y
n <- length(y)
y <- y- sum(y*wt)/sum(wt)
if (continuous) {
# continuous x variable
temp <- cumsum(y*wt)[-n]
left.wt <- cumsum(wt)[-n]
right.wt <- sum(wt) - left.wt
lmean <- temp/left.wt
rmean <- -temp/right.wt
goodness <- (left.wt*lmean^2 + right.wt*rmean^2)/sum(wt*y^2)
list(goodness = goodness, direction = sign(lmean))
} else {
# Categorical X variable
ux <- sort(unique(x))
wtsum <- tapply(wt, x, sum)
ysum <- tapply(y*wt, x, sum)
means <- ysum/wtsum
# For anova splits, we can order the categories by their means
# then use the same code as for a non-categorical
ord <- order(means)
n <- length(ord)
temp <- cumsum(ysum[ord])[-n]
left.wt <- cumsum(wtsum[ord])[-n]
right.wt <- sum(wt) - left.wt
lmean <- temp/left.wt
rmean <- -temp/right.wt
list(goodness= (left.wt*lmean^2 + right.wt*rmean^2)/sum(wt*y^2),
direction = ux[ord])
}
}
# package custom function into list and build tree
ulist <- list(eval = etemp, split = stemp, init = itemp)
mytree <- rpart(y.c ~ x1.c + x2.c +x3.c, data=dat, method = ulist, minsplit = 10)