Получение правильного терминального узла для каждой записи в обучающем наборе ctree - PullRequest
0 голосов
/ 22 ноября 2018

У меня есть следующий код, который перечисляет все терминальные узлы для ctree.Я хотел бы добавить для каждой записи в обучающем наборе (airq, который я обучал этому в этом случае) свой номер конечного узла.Поэтому я добавлю airq вызов столбца TN (терминальный узел), который содержит номер его терминального узла.

CtreePathFunc <- function (ct, data) {

  ResulTable <- data.frame(Node = character(), Path = character())

  for(Node in unique(where(ct))){
    # Taking all possible non-Terminal nodes that are smaller than the selected terminal node
    NonTerminalNodes <- setdiff(1:(Node - 1), unique(where(ct))[unique(where(ct)) < Node])


    # Getting the weigths for that node
    NodeWeights <- nodes(ct, Node)[[1]]$weights


    # Finding the path
    Path <- NULL
    for (i in NonTerminalNodes){
      if(any(NodeWeights & nodes(ct, i)[[1]][2][[1]] == 1)) Path <- append(Path, i)
    }

    # Finding the splitting creteria for that path
    Path2 <- SB <- NULL

    for(i in 1:length(Path)){
      if(i == length(Path)) {
        n <- nodes(ct, Node)[[1]]
      } else {n <- nodes(ct, Path[i + 1])[[1]]}

      if(all(data[which(as.logical(n$weights)), as.character(unlist(nodes(ct,Path[i])[[1]][[5]])[length(unlist(nodes(ct,Path[i])[[1]][[5]]))])] <= as.numeric(unlist(nodes(ct,Path[i])[[1]][[5]])[3]))){
        SB <- "<="
      } else {SB <- ">"}
      Path2 <- paste(c(Path2, paste(as.character(unlist(nodes(ct,Path[i])[[1]][[5]])[length(unlist(nodes(ct,Path[i])[[1]][[5]]))]),
                                    SB,
                                    as.character(unlist(nodes(ct,Path[i])[[1]][[5]])[3]))),
                     collapse = ", ")
    }

    # Output
    ResulTable <- rbind(ResulTable, cbind(Node = Node, Path = Path2))
  }
  return(ResulTable)
}


library(party)
airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone ~ ., data = airq,  controls = ctree_control(maxsurrogate = 3))
Result <- CtreePathFunc(ct, airq)

> Result
  Node                               Path
1    5 Temp <= 82, Wind > 6.9, Temp <= 77
2    3            Temp <= 82, Wind <= 6.9
3    6  Temp <= 82, Wind > 6.9, Temp > 77
4    9             Temp > 82, Wind > 10.3
5    8            Temp > 82, Wind <= 10.3
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...