Пользовательская функция игнорирует оператор if при передаче ему параметра - PullRequest
0 голосов
/ 24 января 2020

Я пытаюсь разработать функцию, позволяющую использовать множество моделей для построения границы решения бинарной классификационной модели. На данный момент мне удалось получить «две» модели, работающие на knn и Logistic модели. Когда я запускаю функцию для модели Logisti c, используя:

mydat <- decisionplot(model = "Logistic",
                      data = data,
                      var1 = var1,
                      var2 = var2,
                      class = class)

Функция работает, однако, когда я запускаю ту же функцию, но использую KNN, вместо нее выдается оператор else no model selected:

mydat <- decisionplot(model = "KNN",
                      data = data,
                      var1 = var1,
                      var2 = var2,
                      class = class)
mydat

Однако KNN является выбранным if оператором в функции if (is.null(model) || model == "KNN"){....

Где я здесь ошибаюсь?

Код и данные:

library(rlang)

data <- iris %>% 
  filter(Species != "setosa")
var1 = "Sepal.Length"
var2 = "Sepal.Width"
class = "Species"
resolution = 0.1
model = "KNN"

decisionplot <- function(model = NULL, data, var1, var2, class,
                         predict_type = "class", resolution = 0.1){
  X_train_data = data[, c(eval_tidy(var1), eval_tidy(var2))]
  Y_train_data <- data[, c(eval_tidy(class))]
  XY_train_data <- cbind(Y_train_data, X_train_data) %>% 
    setNames(c("Y", "X1", "X2"))

  grid <- expand.grid(
    x = seq(
      min(X_train_data[, 1] - 1),
      max(X_train_data[, 1] + 1),
      by = resolution
      ),
    y = seq(
      min(X_train_data[, 2] - 1),
      max(X_train_data[, 2] + 1), 
      by = resolution
      )
  )
  if (is.null(model) || model == "KNN"){
    mini_model <- class::knn(X_train_data, grid, Y_train_data, k = 2, prob = TRUE)
    mini_model_probs <- attr(mini_model, "prob")

    data_plot <- bind_rows(
      mutate(
        grid,
        prob = mini_model_probs,
        class = "Non-Bankrupt",
        prob_class = ifelse(
          mini_model == 0, 1, 0
          )
        ),
      mutate(
        grid,
        prob = mini_model_probs,
        class = "Bankrupt",
        prob_class = ifelse(
          mini_model == 0, 1, 0
          )
        )
      )
    #return(list(data_plot, mini_model, mini_model_probs))
  } 
 if (is.null(model) || model == "Logistic"){
    mini_model <- glm(Y ~ X1 + X2,
                      data = XY_train_data, family = "binomial")
    mini_model_probs <- predict(object = mini_model, newdata = grid %>% 
                                setNames(c("X1", "X2")), type = 'response')
    data_plot <- bind_rows(
      mutate(
        grid,
        prob = mini_model_probs,
        class = "Non-Bankrupt",
        prob_class = ifelse(
          mini_model_probs >= 0.5, 1, 0
        )
      ),
      mutate(
        grid,
        prob = mini_model_probs,
        class = "Bankrupt",
        prob_class = ifelse(
          mini_model_probs >= 0.5, 1, 0
        )
      )
    )
    #return(list(data_plot, mini_model, mini_model_probs))
  }
else{
    return("no model selected")
}
  return(list(
    X_train_data,
    Y_train_data,
    mini_model,
    mini_model_probs,
    data_plot,
    grid,
    XY_train_data))
}


mydat <- decisionplot(model = "Logistic",
                      data = data,
                      var1 = var1,
                      var2 = var2,
                      class = class)

mydat
###
ggplot() +
  geom_point(aes(x = x, y = y, colour = class, size = prob_class),
             data = mydat[[5]]) +
  scale_size(range=c(0.2, 1)) +
  geom_contour(aes(x = x, y = y, z = prob_class, group = factor(class), color = factor(class)),
               bins = 2,
               data = mydat[[5]]) +
  geom_point(aes(x = X1, y = X2, color = factor(Y)),
             size = 3,
             alpha = 0.2,
             data = mydat[[7]]) +
  geom_point(aes(x = X1, y = X2),
             size = 3,
             shape = 1,
             alpha = 0.2,
             data = mydat[[7]])

1 Ответ

1 голос
/ 25 января 2020

Может быть трудно сказать, но скобки могут быть отключены, когда вы добавили else.

Вы также можете создать отдельную функцию при создании data_plot - по крайней мере, это поможет вам визуализировать, как if/else может быть применено.

Я надеюсь, что это может быть полезно для продвижения вперед.

data_plot_fn <- function(grid, mini_model_probs, prob_class){
  bind_rows(
    mutate(
      grid,
      prob = mini_model_probs,
      class = "Non-Bankrupt",
      prob_class = prob_class
    ),
    mutate(
      grid,
      prob = mini_model_probs,
      class = "Bankrupt",
      prob_class = prob_class
    )
  )
}

decisionplot <- function(model = NULL, data, var1, var2, class,
                         predict_type = "class", resolution = 0.1){
  X_train_data = data[, c(eval_tidy(var1), eval_tidy(var2))]
  Y_train_data <- data[, c(eval_tidy(class))]
  XY_train_data <- cbind(Y_train_data, X_train_data) %>% 
    setNames(c("Y", "X1", "X2"))

  grid <- expand.grid(
    x = seq(
      min(X_train_data[, 1] - 1),
      max(X_train_data[, 1] + 1),
      by = resolution
    ),
    y = seq(
      min(X_train_data[, 2] - 1),
      max(X_train_data[, 2] + 1), 
      by = resolution
    )
  )

  if (model == "KNN"){
    message("KNN Model")
    mini_model <- class::knn(X_train_data, grid, Y_train_data, k = 2, prob = TRUE)
    mini_model_probs <- attr(mini_model, "prob")
    prob_class = ifelse(mini_model == 0, 1, 0)
    data_plot <- data_plot_fn(grid, mini_model_probs, prob_class) 
  } else if (model == "Logistic"){
    message("Logistic Model")
    mini_model <- glm(Y ~ X1 + X2, data = XY_train_data, family = "binomial")
    mini_model_probs <- predict(object = mini_model, newdata = grid %>% 
                                  setNames(c("X1", "X2")), type = 'response')
    prob_class = ifelse(mini_model_probs >= 0.5, 1, 0)
    data_plot <- data_plot_fn(grid, mini_model_probs, prob_class) 
  } else {
    warning("No model selected")
    return(NULL)
  }
  return(list(
    X_train_data,
    Y_train_data,
    mini_model,
    mini_model_probs,
    data_plot,
    grid,
    XY_train_data))
}
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...