У меня есть дерево с именем mytree
, которое выглядит следующим образом:
В R оно хранится в виде списка:
mytree <- list(left = structure(list(y = -10, x = 10, grad = -10.5, sim_score = 110.25,
value = -10.5, criterion = "x < 15"), row.names = 1L, class = "data.frame"),
right = list(left = list(left = structure(list(y = 7, x = 20,
grad = 6.5, sim_score = 42.25, value = 6.5, criterion = "x < 22.5"), row.names = 2L, class = "data.frame"),
right = structure(list(y = 8, x = 25, grad = 7.5, sim_score = 56.25,
value = 7.5, criterion = "x >= 22.5"), row.names = 3L, class = "data.frame"),
root = list(root = structure(list(y = c(7, 8), x = c(20,
25), grad = c(6.5, 7.5), sim_score = c(98, 98), value = c(7,
7), criterion = c("x < 30", "x < 30")), row.names = 2:3, class = "data.frame"),
gain = 0.5)), right = structure(list(y = -7, x = 35,
grad = -7.5, sim_score = 56.25, value = -7.5, criterion = "x >= 30"), row.names = 4L, class = "data.frame"),
root = list(root = structure(list(y = c(7, 8, -7), x = c(20,
25, 35), grad = c(6.5, 7.5, -7.5), sim_score = c(14.0833333333333,
14.0833333333333, 14.0833333333333), value = c(2.16666666666667,
2.16666666666667, 2.16666666666667), criterion = c("x >= 15",
"x >= 15", "x >= 15")), row.names = 2:4, class = "data.frame"),
gain = 140.166666666667)), root = list(root = structure(list(
y = c(-10, 7, 8, -7), x = c(10, 20, 25, 35), grad = c(-10.5,
6.5, 7.5, -7.5), sim_score = c(4, 4, 4, 4)), row.names = c(NA,
-4L), class = "data.frame"), gain = 120.333333333333))
, что выглядит так
$left
y x grad sim_score value criterion
1 -10 10 -10.5 110.25 -10.5 x < 15
$right
$right$left
$right$left$left
y x grad sim_score value criterion
2 7 20 6.5 42.25 6.5 x < 22.5
$right$left$right
y x grad sim_score value criterion
3 8 25 7.5 56.25 7.5 x >= 22.5
$right$left$root
$right$left$root$root
y x grad sim_score value criterion
2 7 20 6.5 98 7 x < 30
3 8 25 7.5 98 7 x < 30
$right$left$root$gain
[1] 0.5
$right$right
y x grad sim_score value criterion
4 -7 35 -7.5 56.25 -7.5 x >= 30
$right$root
$right$root$root
y x grad sim_score value criterion
2 7 20 6.5 14.08333 2.166667 x >= 15
3 8 25 7.5 14.08333 2.166667 x >= 15
4 -7 35 -7.5 14.08333 2.166667 x >= 15
$right$root$gain
[1] 140.1667
$root
$root$root
y x grad sim_score
1 -10 10 -10.5 4
2 7 20 6.5 4
3 8 25 7.5 4
4 -7 35 -7.5 4
$root$gain
[1] 120.3333
Разделения хранятся в criterion
, а значения выхода сохраняются в value
.
Учитывая новую точку данных, x = 5
, я хотел бы запросить mytree
и посмотреть, под какой листовой узел попадает этот экземпляр. Для x = 5
моя функция должна вывести значение -10.5
, потому что 5 < 15
. Аналогично, если x = 25
, то он должен оказаться в листе со значением 7.5
. Вот еще несколько примеров того, что я бы хотел, чтобы моя функция pred_tree
выводила:
newdata <- data.frame(x = c(5, 19, 18, 30))
> pred_tree(tree = mytree, newdata = newdata)
[1] -10.5
[2] 6.5
[3] 6.5
[4] -7.5
Вот что у меня есть на данный момент:
pred_tree <- function(tree, newdata){
for(i in length(tree)){
# Check if this is a leaf
if(length(tree[[i]]) == 1){
# Check criterion
if(eval(parse(text=tree[[i]]$criterion))){
# Return value of leaf
return(tree[[i]]$value[1])
}
}else if(length(tree[[i]]) > 1){
for(j in 1:length(tree[[i]])){
if(length(tree[[i]][[j]]) == 1){
# Check criterion
if(eval(parse(text=tree[[i]][[j]]$criterion))){
# Return value of leaf
return(tree[[i]][[j]]$value[1])
}
}
}
}
}
}
pred_tree(tree, newdata = newdata)
К сожалению, эта функция не возвращает правильный вывод. Кроме того, это довольно неуклюже и может быть очень медленным, если мне нужно выполнить много запросов. Я предполагаю, что использование рекурсивного алгоритма имело бы больший смысл вместо использования вложенных циклов for. Может ли кто-нибудь указать мне правильное направление?
@@@@@@@@@@@@@ РЕДАКТИРОВАТЬ @@@@@@@@@@@@@
mytree3 <- list(left = list(left = structure(list(y = -10, x = 10, grad = 0,
sim_score = 0, value = 0, criterion = "x < 15"), row.names = 1L, class = "data.frame"),
right = structure(list(y = 7, x = 20, grad = -0.5, sim_score = 0.25,
value = -0.5, criterion = "x >= 15"), row.names = 2L, class = "data.frame"),
root = list(root = structure(list(y = c(-10, 7), x = c(10,
20), grad = c(0, -0.5), sim_score = c(0.125, 0.125), value = c(-0.25,
-0.25), criterion = c("x < 22.5", "x < 22.5")), row.names = 1:2, class = "data.frame"),
gain = 0.125)), right = list(left = structure(list(y = 8,
x = 25, grad = 0.5, sim_score = 0.25, value = 0.5, criterion = "x < 30"), row.names = 3L, class = "data.frame"),
right = structure(list(y = -7, x = 35, grad = 0, sim_score = 0,
value = 0, criterion = "x >= 30"), row.names = 4L, class = "data.frame"),
root = list(root = structure(list(y = c(8, -7), x = c(25,
35), grad = c(0.5, 0), sim_score = c(0.125, 0.125), value = c(0.25,
0.25), criterion = c("x >= 22.5", "x >= 22.5")), row.names = 3:4, class = "data.frame"),
gain = 0.125)), root = list(root = structure(list(y = c(-10,
7, 8, -7), x = c(10, 20, 25, 35), grad = c(0, -0.5, 0.5, 0),
sim_score = c(0, 0, 0, 0), value = c(0, 0, 0, 0)), row.names = c(NA,
-4L), class = "data.frame"), gain = 0.25))
Выполнение следующего не дало правильного вывода
pred_tree(tree = mytree3, newdata = newdata)