Попытка создать функцию для пакета, которая автоматически отображает ответ переменной для данной модели при наличии категориальных переменных. - PullRequest
3 голосов
/ 01 ноября 2019

Я пытаюсь создать функцию, которая отображает ответ выбранной переменной, когда хотя бы одна из переменных является категориальной.

Когда все ваши переменные являются числовыми, я обычно сохраняювсе остальные переменные имеют среднее значение, а затем изменяют целевую переменную, вот пример с mtcars:

library(tidyverse)

data("mtcars")

Сначала я изменю переменную am, чтобы она была категориальной переменной

mt2 <- mtcars %>% mutate(am = case_when(am == 0 ~ "Automatic", am == 1 ~ "Manual")) %>% select(mpg, am, wt, hp)

Тогда я покажу, что работает для меня

Это работает

для модели с только числовыми переменными, у меня нет проблем, например, с этой моделью

model1 <- lm(mpg ~ wt + hp, data = mt2)

Я могу использовать эту функцию, которую я сделал

Plot_Response <- function(Model, variable){

  # generate a data.frame with all the means copied 20 times
  Means <- Model$model %>% summarise_all(mean)
  Means <- Means[rep(seq_len(nrow(Means)), each = 20),]

  # Then generate a vector with a sequence from the min value to the max value of the variable
  MinMax <- Model$model %>% select(variable) %>% pull(variable) %>% range()
  MinMax <- seq(from = MinMax[1], to = MinMax[2], along.with = Means[,1])

  # Replace the column of the variable that we need to plot the response plot of by this sequence

  Means[colnames(Means)== as.character(variable)] <- MinMax

  ## Predict the fit and SE

  Means$Predicted <-predict(Model, newdata = Means)
  Means$SE <- predict(Model, newdata = Means, se.fit = T)$se.fit

  ## Plot the response
  result <- ggplot(Means, aes_string(x= variable, y = "Predicted")) + geom_ribbon(aes(ymax= Predicted + SE, ymin = Predicted - SE), fill = "grey") + geom_line() + theme_classic() 

  return(result)
}

Если я использую эту функцию с моделью выше, я могу сделать этот график

Plot_Response(Model = model1, variable = "wt")

enter image description here

вот когда я попадаю в беду

Конечно, если я попробую это, когда есть категориальная переменная, у меня возникнут проблемы, так как, если она пытается получить среднее значение категориальной величины дляфрейм данных это не удается:

model2 <- lm(mpg ~ wt + hp + am, data = mt2)

Если я попытаюсь

Plot_Response(Model = model2, variable = "wt")

я получу:

Error: variable 'am' was fitted with type "character" but type "numeric" was supplied

Итак, я попробовал следующее:

Plot_Response2 <- function(Model, variable){

  # First I get the names of all categorical variables
  Categoricals <- Model$model %>% mutate_if(is.factor, as.character) %>% select_if(is.character) %>% colnames()

  # generate a data.frame with all the means copied 20 times for each level

   Means <- Model$model %>% mutate_if(is.factor, as.character) %>% mutate_if(is.numeric, mean) %>% group_by_if(is.character) %>% summarise_if(is.numeric, mean) %>% ungroup()
   Means <- Means[rep(seq_len(nrow(Means)), each = 20),]  %>% arrange_if(is.character) %>% group_split(substitute(variable))

  return(Means)
}

Моя идея заключается в том, чтоФункция определит, какие переменные являются категориальными. Если я запрашиваю ответ числовой переменной, я получаю ответ этой переменной на каждом уровне категориальной переменной. Моя проблема до сих пор заключается в том, что когда я делаю group_split, он не распознает переменную

Что яожидайте:

Пример 1

Я ожидаю, что если я сделаю:

Plot_Response2(Model = model2, variable = "wt")

Я получу:

enter image description here

Я сделал для этого следующий код, но не смог добавить его в функцию:

Means <- model2$model %>% mutate_if(is.factor, as.character) %>% mutate_if(is.numeric, mean) %>% group_by_if(is.character) %>% summarise_if(is.numeric, mean) %>% ungroup()
Means <- Means[rep(seq_len(nrow(Means)), each = 20),]  %>% arrange_if(is.character) %>% group_split(am)

MinMax <- model2$model %>% select(wt) %>% pull(wt) %>% range()
MinMax <- seq(from = MinMax[1], to = MinMax[2], length.out = 20)

for(i in 1:length(Means)){
  Means[[i]]$wt <- MinMax
}

Means <- bind_rows(Means)
Means$Predicted <- predict(model2, Means)
Means$SE <- predict(model2, Means, se.fit = T)$se.fit

ggplot(Means, aes(x = wt, y = Predicted)) + geom_ribbon(aes(ymax = Predicted + SE, ymin = Predicted - SE, fill = am), alpha = 0.5) + geom_line(aes(color = am)) + theme_classic()

Пример 1

Я ожидаю, что если яdo:

Plot_Response2(Model = model2, variable = "am")

Я получу:

enter image description here

Опять же для этого я использовал этот код, который я не могуПохоже, что вместе с функцией 2

Means <- model2$model %>% mutate_if(is.factor, as.character) %>% mutate_if(is.numeric, mean) %>% group_by_if(is.character) %>% summarise_if(is.numeric, mean) %>% ungroup()
Means <- Means[rep(seq_len(nrow(Means)), each = 20),]  %>% arrange_if(is.character) %>% group_split(am)

Means <- bind_rows(Means)
Means$Predicted <- predict(model2, Means)
Means$SE <- predict(model2, Means, se.fit = T)$se.fit

ggplot(Means, aes(x = am, y = Predicted)) + geom_errorbar(aes(ymin = Predicted - SE, ymax = Predicted + SE)) + geom_point() + theme_classic()

Любая помощь или предложение очень ценится, и любые необходимые разъяснения я отвечу.

Спасибо

Ответы [ 2 ]

3 голосов
/ 05 ноября 2019

Вот версия, которая использует несколько дополнительных функций для упрощения работы.

Plot_Response <- function(Model, variable, N=20) {
  model_data <- model.frame(Model)
  stopifnot(variable %in% names(model_data))

  # get all variables we need to dummy values for
  all_vars <- model_data %>% select(-one_of(variable))
  num_vars <- all_vars %>% select_if(is.numeric) %>% summarize_all(mean)
  cat_vars <- all_vars %>% select_if(Negate(is.numeric)) %>% purrr::map(unique)

  resp_var <- model_data %>% pull(variable) 
  if(is.numeric(resp_var)) {
    resp_vals <- seq(min(resp_var), max(resp_var), length.out=N)
  } else {
    resp_vals <- unique(resp_var)
  }

  new_data <- tidyr::crossing(num_vars, !!!cat_vars, !!variable:=resp_vals)

  pred <- broom::augment(Model, newdata = new_data, se_fit=TRUE)

  ## Plot the response
  my_aes <- aes(x= !!sym(variable), y = .fitted)
  if (length(cat_vars)==1) {
    my_aes[["fill"]] <- sym(names(cat_vars))
  } else if (length(cat_vars)>1) {
    my_aes[["fill"]] <- quo(interaction(!!!syms(names(cat_vars))))
  }
  range_aes <- aes(ymax= .fitted + .se.fit, ymin = .fitted - .se.fit)
  result <- ggplot(pred, my_aes) + theme_classic() + ylab("Predicted")
  if(is.numeric(resp_var)) {
    result + 
      (if (length(cat_vars)>0) {
        geom_ribbon(range_aes) 
      } else {
        geom_ribbon(range_aes, fill="grey")
      }) + 
      geom_line()
  } else {
    result + 
      geom_errorbar(range_aes) + 
      geom_point() 
  } 

}

Это работает для обоих перечисленных вами случаев

model1 <- lm(mpg ~ wt + hp + am, data = mt2)
Plot_Response(model1, "wt")
Plot_Response(model1, "am")
0 голосов
/ 05 ноября 2019

Я полагаю, что вы описываете точно проблему, решаемую пакетом DescTools . Первая строка описания пакета говорит об этом очень хорошо:

DescTools - это обширная коллекция различных базовых статистических функций и удобных упаковщиков, недоступных в базовой системе R для эффективного описания данных.

Я не фанат загрузки ряда пакетов, чтобы выполнить работу в R. Однако я делаю исключение для этого. Я думаю, что обширная коллекция инструментов Андри Синьорелла действительно выдающаяся. Могут быть конфликты между функциями, определенными в DescTools, и функциями в tidyverse, поэтому я формулирую свой ответ, не прибегая к tidyverse.

# DescTools needs to be available
  if (!require(DescTools)) {
    install.packages("DescTools")
  }
  library(DescTools)

# Create factors in mtcars
  mt3 <- mtcars
  mt3$am <- factor(mt3$am, labels = c("man", "auto"))
  mt3$vs <- factor(mt3$vs, labels = c("v", "str"))

Примеры, приведенные в вопросе, представляют собой варианты построения графиков. mpg как описано различными другими переменными. Если цель здесь состоит в том, чтобы на самом деле написать для этой цели общую функцию, то этот ответ бесполезен. Однако, если цель состоит в том, чтобы удобно визуализировать переменные, как указано: «Я хочу получить функцию, которая автоматически находит, какие переменные являются категориальными, которые являются непрерывными, и соответственно отображает ответ», тогда я думаю, что DescTools - отличный ответ!

DescTools не является инструментом построения графиков общего назначения. Я не верю, что вы можете построить одновременные регрессии, как это сделано с кодом ggplot. Тем не менее, он отлично показывает разумный график для выбранных переменных. Сначала запрошены два примера:

# mpg as a function of weight
  dev.new(width = 6, height = 4.5)
  opar <- par(mfrow = c(1, 2))
  Desc(mpg ~ wt, mt3, main = "Manual", subset = am == "man")
  Desc(mpg ~ wt, mt3, main = "Automatic", subset = am == "auto")
  par(opar)

mpg by wt

# mpg as a function of transmission
  Desc(mpg ~ am, mt3)

mpg vs transmission
И еще два примеранасколько просто это может быть для взаимодействующих факторов и даже для одной непрерывной переменной.

  Desc(mpg ~ am:vs, mt3)
  Desc(mt3$qsec)  

mpg vs interaction quarter mile time

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...