Наложение дерева на набор данных с помощью ggcontour - PullRequest
3 голосов
/ 18 апреля 2019

Мне поручено приспособить модель дерева классификации к наблюдениям ниже. Затем я должен поместить дерево поверх существующих данных. Было рекомендовано использовать p + geom_contour(....), но я мало знаком с ggplot.

Код, который мне указан ниже.

Я могу довольно легко приспособить модель дерева к данным, но построение графика дает только дерево решений. Как я могу наложить модель дерева на существующий график, используя geom_contour?

library(tidyverse)
set.seed(1234)
dat <- tibble(
    x1 = rnorm(100),
    x2 = rnorm(100)
) %>% mutate(y = as_factor(ifelse(x1^2 + x2^2 > 1.39, "A", "B")))

circlepts <- tibble(theta = seq(0, 2*pi, length = 100)) %>%
    mutate(x = sqrt(1.39) * sin(theta), y = sqrt(1.39) * cos(theta))

p <- ggplot(dat) + geom_point(aes(x1, x2, color = y)) + coord_fixed() +
    geom_polygon(data = circlepts, aes(x, y), color = "blue", fill = NA)
p

Чтобы приспособить модель дерева к данным, я набираю

library(tree)
tree_fit <- tree(y~., dat)

Наложением будет просто дерево решений, соответствующее данным, например, как то так (грубо прорисовано в MS Paint)

enter image description here

1 Ответ

2 голосов
/ 18 апреля 2019

Я не думаю, что geom_contour - это способ сделать это, но вы можете получить координаты для отрезков линии из базового фрейма данных tree_fit и выполнить некоторый спор, чтобы постепенно ограничить каждый сегмент неподвижным «активная» площадь на участке:

tree.df.segment <- tree_fit$frame %>% 
  rownames_to_column() %>% 
  mutate(rowname = as.integer(rowname),
         depth = tree:::tree.depth(rowname),
         split = splits[, 1] %>%
           gsub("<|>", "", .) %>%
           as.numeric()) %>%

  arrange(depth, rowname) %>%
  mutate(leaf.position = case_when(lead(depth) > depth & lead(var) == "<leaf>" ~ "left",
                                   lead(depth) > depth & lead(var) != "<leaf>" ~ "right",
                                   TRUE ~ NA_character_)) %>%
  fill(leaf.position, .direction = "up") %>%
  filter(var != "<leaf>") %>%
  select(depth, var, split, leaf.position) %>%

  # define basic segment coordinates
  mutate(x = -Inf, xend = Inf, y = -Inf, yend = Inf,
         xmin = -Inf, xmax = Inf, ymin = -Inf, ymax = Inf) %>%

  # modify coordinates of segment / active area based on split
  mutate(x    = ifelse(var == "x1", split, x),
         xend = ifelse(var == "x1", split, xend),
         y    = ifelse(var == "x2", split, y),
         yend = ifelse(var == "x2", split, yend),
         xmin = ifelse(var == "x1" & leaf.position ==  "left", split, xmin),
         xmax = ifelse(var == "x1" & leaf.position == "right", split, xmax),
         ymin = ifelse(var == "x2" & leaf.position ==  "left", split, ymin),
         ymax = ifelse(var == "x2" & leaf.position == "right", split, ymax)) %>%
  # shrink active area progressively as depth increases
  mutate(xmin = cummax(xmin), xmax = cummin(xmax),
         ymin = cummax(ymin), ymax = cummin(ymax)) %>%
  # limit segment coordinates to within active area
  mutate(x = pmax(x, xmin), xend = pmin(xend, xmax),
         y = pmax(y, ymin), yend = pmin(yend, ymax))

p + 
  geom_segment(data = tree.df.segment,
               aes(x = x, xend = xend, y = y, yend = yend))

plot


Кроме того (поскольку я думаю, что после этого будет задан вопрос), мы можем заштриховать каждую область, соответствующую конечному листу, в виде прямоугольника, используя geom_rect. Это потребует некоторых дополнительных споров.

tree.df.rect <- tree.df.segment %>%
  mutate(depth = depth + 1) %>%
  select(-c(x, xend, y, yend)) %>%
  mutate_at(vars(xmin, xmax, ymin, ymax), list(rect = lag)) %>%
  mutate_at(vars(xmin_rect, ymin_rect), ~ifelse(is.na(.), -Inf, .)) %>%
  mutate_at(vars(xmax_rect, ymax_rect), ~ifelse(is.na(.), Inf, .)) %>%
  mutate(xmin_rect = ifelse(var == "x1" & leaf.position == "right", split, xmin_rect),
         xmax_rect = ifelse(var == "x1" & leaf.position ==  "left", split, xmax_rect),
         ymin_rect = ifelse(var == "x2" & leaf.position == "right", split, ymin_rect),
         ymax_rect = ifelse(var == "x2" & leaf.position ==  "left", split, ymax_rect)) %>%
  # add label for each rect
  full_join(tree_fit$frame %>%
              rownames_to_column() %>%
              mutate(rowname = as.integer(rowname),
                     depth = tree:::tree.depth(rowname),
                     split = splits[, 1] %>%
                       gsub("<|>", "", .) %>%
                       as.numeric()) %>%
              filter(var == "<leaf>") %>%
              select(depth, rowname, yval) %>%
              arrange(depth, rowname))
# since last split is associated with two rectangles, determine which is the last 'active'
# one in order to assign the labels correctly (doesn't matter in this case since the last
# two labels are both 'B', but this should apply more generally)
if(tree.df.rect %>% filter(depth == max(depth)) %>% pull(leaf.position) %>% unique() == "left") {
  tree.df.rect[nrow(tree.df.rect), c("xmin_rect", "xmax_rect", "ymin_rect", "ymax_rect")] <-
    tree.df.rect[nrow(tree.df.rect), c("xmin", "xmax", "ymin", "ymax")]
} else {
  tree.df.rect[nrow(tree.df.rect)-1, c("xmin_rect", "xmax_rect", "ymin_rect", "ymax_rect")] <-
    tree.df.rect[nrow(tree.df.rect)-1, c("xmin", "xmax", "ymin", "ymax")]
}
tree.df.rect <- tree.df.rect %>%
  select(depth, yval, xmin_rect, xmax_rect, ymin_rect, ymax_rect)

# combine into one data frame
tree.df <- full_join(
  tree.df.rect %>%
    select(depth, yval, xmin_rect, xmax_rect, ymin_rect, ymax_rect),
  tree.df.segment %>%
    select(depth, x, xend, y, yend)
)

p.shaded <- ggplot(data = tree.df) + 
  geom_point(data = dat, aes(x1, x2, color = y)) + 
  geom_polygon(data = circlepts, aes(x, y), color = "blue", fill = NA) + 
  geom_rect(aes(xmin = xmin_rect, xmax = xmax_rect,
                ymin = ymin_rect, ymax = ymax_rect,
                fill = yval),
            alpha = 0.25) +
  geom_segment(aes(x = x, xend = xend, y = y, yend = yend)) +
  coord_fixed() +
  labs(color = "", fill = "") +
  scale_fill_discrete(breaks = c("A", "B"))

p.shaded

shaded plot

Который может быть легко расширен в анимированную форму:

library(gganimate)

p.anim <- p.shaded +  
  transition_states(depth) +
  shadow_mark() +
  enter_fade() +
  labs(title = "{closest_state}")

animate(p.anim, nframes = 10, fps = 1)

animated plot

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...