Как мне построить переменную важность моей обученной модели дерева решений rpart? - PullRequest
0 голосов
/ 25 мая 2019

Я обучил модель, используя rpart, и я хочу сгенерировать график, отображающий значение переменных для переменных, которые он использовал для дерева решений, но я не могу понять, как.

Мне удалось извлечь значение переменной. Я пробовал ggplot, но никакой информации не видно. Я попытался использовать функцию plot (), но она дает мне только плоский график. Я также попробовал plot.default, который немного лучше, но все еще сейчас, что я хочу.

Вот обучение модели rpart:

argIDCART = rpart(Argument ~ ., 
                  data = trainSparse, 
                  method = "class")

Получил значение переменной во фрейм данных.

argPlot <- as.data.frame(argIDCART$variable.importance)

Вот часть того, что это печатает:

       argIDCART$variable.importance
noth                             23.339346
humanitarian                     16.584430
council                          13.140252
law                              11.347241
presid                           11.231916
treati                            9.945111
support                           8.670958

Я хотел бы построить график, который показывает имя переменной / функции и ее числовую значимость. Я просто не могу заставить это сделать это. Похоже, только один столбец. Я попытался разделить их, используя отдельную функцию, но тоже не могу этого сделать.

ggplot(argPlot, aes(x = "variable importance", y = "feature"))

Просто печатает пустым.

Другие сюжеты выглядят очень плохо.

plot.default(argPlot)

Похоже, что она строит точки, но не ставит имя переменной.

Ответы [ 2 ]

0 голосов
/ 25 мая 2019

Поскольку воспроизводимого примера не существует, я смонтировал свой ответ на основе собственного набора данных R, используя пакет ggplot2 и другие пакеты для обработки данных.

library(rpart)
library(tidyverse)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
df <- data.frame(imp = fit$variable.importance)
df2 <- df %>% 
  tibble::rownames_to_column() %>% 
  dplyr::rename("variable" = rowname) %>% 
  dplyr::arrange(imp) %>%
  dplyr::mutate(variable = forcats::fct_inorder(variable))
ggplot2::ggplot(df2) +
  geom_col(aes(x = variable, y = imp),
           col = "black", show.legend = F) +
  coord_flip() +
  scale_fill_grey() +
  theme_bw()

enter image description here

ggplot2::ggplot(df2) +
  geom_segment(aes(x = variable, y = 0, xend = variable, yend = imp), 
               size = 1.5, alpha = 0.7) +
  geom_point(aes(x = variable, y = imp, col = variable), 
             size = 4, show.legend = F) +
  coord_flip() +
  theme_bw()

enter image description here

0 голосов
/ 25 мая 2019

Если вы хотите увидеть имена переменных, лучше всего использовать их в качестве меток на оси х.

plot(argIDCART$variable.importance, xlab="variable", 
    ylab="Importance", xaxt = "n", pch=20)
axis(1, at=1:7, labels=row.names(argIDCART))

Variable Importance

(Вам может потребоваться изменить размер окна, чтобы правильно видеть метки.)

Если у вас много переменных, вы можете повернуть имена переменных, чтобы они не перекрывались.

par(mar=c(7,4,3,2))
plot(argIDCART$variable.importance, xlab="variable", 
    ylab="Importance", xaxt = "n", pch=20)
axis(1, at=1:7, labels=row.names(argIDCART), las=2)

Rotated axis labels

Данные

argIDCART = read.table(text="variable.importance
noth                             23.339346
humanitarian                     16.584430
council                          13.140252
law                              11.347241
presid                           11.231916
treati                            9.945111
support                           8.670958", 
header=TRUE)
...