Вес в cforest 'tree' в сумме превышает размер выборки - PullRequest
0 голосов
/ 26 июня 2018

Я впервые задаю вопрос здесь, поэтому, пожалуйста, будьте добры ... и я прошу прощения, если это не то место, где можно задать этот вопрос (грань между программированием на R и вопросом статистики может быть немного размыта мне) - если так, я с удовольствием попробую stackexchange.

Я запускаю cforest в R (пакет участника), чтобы предсказать числовую переменную ответа, и использую превосходный обходной путь get_cTree, предложенный @Marco Sandri здесь: https://stackoverflow.com/a/34534978/9989544 создать дерево, чтобы попытаться понять правила, которые оно использует для разделения (это меня интересует, хотя мой основной фокус - переменная важность).

Я ожидал, что веса по всем узлам суммируются с моим общим размером выборки, что и происходит, если я запускаю одно ctree.

Однако, что на самом деле происходит при использовании кода get_cTree Марко Сандри, так это то, что веса нескольких пар узлов суммируются с размером моей выборки, а оставшиеся веса вообще не суммируются с размером моей выборки. Общий вес превышает мой общий размер выборки.

Является ли это естественным следствием попытки вывести дерево из условного леса - т. Е. Оно действительно не разделяет данные на отдельные узлы? - или это можно решить с помощью программирования?

Вот пример (код get_cTree от Марко Сандри). Для набора данных радужной оболочки n = 150. Сумма весов для узлов, которые я получаю для cforest, равна 566, а это 150 с использованием ctree (пакет party).

library(party)

update_tree <- function(x, dt) {
  x <- update_weights(x, dt)
  if(!x$terminal) {
    x$left <- update_tree(x$left, dt)
    x$right <- update_tree(x$right, dt)   
  } 
  x
}

update_weights <- function(x, dt) {
  splt <- x$psplit
  spltClass <- attr(splt,"class")
  spltVarName <- splt$variableName
  spltVar <- dt[,spltVarName]
  spltVarLev <- levels(spltVar)
  if (!is.null(spltClass)) {
    if (spltClass=="nominalSplit") {
      attr(x$psplit$splitpoint,"levels") <- spltVarLev   
      filt <- spltVar %in% spltVarLev[as.logical(x$psplit$splitpoint)] 
    } else {
      filt <- (spltVar <= splt$splitpoint)
    }
    x$left$weights <- as.numeric(filt)
    x$right$weights <- as.numeric(!filt)
  }
  x
}

get_cTree <- function(cf, k=1) {
  dt <- cf@data@get("input")
  tr <- party:::prettytree(cf@ensemble[[k]], names(dt))
  tr_updated <- update_tree(tr, dt)
  new("BinaryTree", tree=tr_updated, data=cf@data, responses=cf@responses, 
      cond_distr_response=cf@cond_distr_response, predict_response=cf@predict_response)
}

attach(iris)

SepalLength <- as.numeric(iris$Sepal.Length)

SepalWidth <- as.numeric(iris$Sepal.Width)

PetalLength <- as.numeric(iris$Petal.Length)

PetalWidth <- as.numeric(iris$Petal.Width)

Species <- as.factor(iris$Species)

mtry=ceiling(sqrt(4))

set.seed(1)

iris_cforest <- cforest(PetalLength~SepalLength+SepalWidth+PetalWidth+Species,controls=cforest_unbiased(ntree=1000,mtry=mtry))

iristree <- get_cTree(iris_cforest)

iristree

plot(iristree)

set.seed(1)

iris_ctree <- ctree(PetalLength~SepalLength+SepalWidth+PetalWidth+Species)

iris_ctree

plot(iris_ctree)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...