Я пытаюсь разработать функцию, позволяющую использовать множество моделей для построения границы решения бинарной классификационной модели. На данный момент мне удалось получить «две» модели, работающие на knn
и Logistic
модели. Когда я запускаю функцию для модели Logisti c, используя:
mydat <- decisionplot(model = "Logistic",
data = data,
var1 = var1,
var2 = var2,
class = class)
Функция работает, однако, когда я запускаю ту же функцию, но использую KNN
, вместо нее выдается оператор else
no model selected
:
mydat <- decisionplot(model = "KNN",
data = data,
var1 = var1,
var2 = var2,
class = class)
mydat
Однако KNN
является выбранным if
оператором в функции if (is.null(model) || model == "KNN"){...
.
Где я здесь ошибаюсь?
Код и данные:
library(rlang)
data <- iris %>%
filter(Species != "setosa")
var1 = "Sepal.Length"
var2 = "Sepal.Width"
class = "Species"
resolution = 0.1
model = "KNN"
decisionplot <- function(model = NULL, data, var1, var2, class,
predict_type = "class", resolution = 0.1){
X_train_data = data[, c(eval_tidy(var1), eval_tidy(var2))]
Y_train_data <- data[, c(eval_tidy(class))]
XY_train_data <- cbind(Y_train_data, X_train_data) %>%
setNames(c("Y", "X1", "X2"))
grid <- expand.grid(
x = seq(
min(X_train_data[, 1] - 1),
max(X_train_data[, 1] + 1),
by = resolution
),
y = seq(
min(X_train_data[, 2] - 1),
max(X_train_data[, 2] + 1),
by = resolution
)
)
if (is.null(model) || model == "KNN"){
mini_model <- class::knn(X_train_data, grid, Y_train_data, k = 2, prob = TRUE)
mini_model_probs <- attr(mini_model, "prob")
data_plot <- bind_rows(
mutate(
grid,
prob = mini_model_probs,
class = "Non-Bankrupt",
prob_class = ifelse(
mini_model == 0, 1, 0
)
),
mutate(
grid,
prob = mini_model_probs,
class = "Bankrupt",
prob_class = ifelse(
mini_model == 0, 1, 0
)
)
)
#return(list(data_plot, mini_model, mini_model_probs))
}
if (is.null(model) || model == "Logistic"){
mini_model <- glm(Y ~ X1 + X2,
data = XY_train_data, family = "binomial")
mini_model_probs <- predict(object = mini_model, newdata = grid %>%
setNames(c("X1", "X2")), type = 'response')
data_plot <- bind_rows(
mutate(
grid,
prob = mini_model_probs,
class = "Non-Bankrupt",
prob_class = ifelse(
mini_model_probs >= 0.5, 1, 0
)
),
mutate(
grid,
prob = mini_model_probs,
class = "Bankrupt",
prob_class = ifelse(
mini_model_probs >= 0.5, 1, 0
)
)
)
#return(list(data_plot, mini_model, mini_model_probs))
}
else{
return("no model selected")
}
return(list(
X_train_data,
Y_train_data,
mini_model,
mini_model_probs,
data_plot,
grid,
XY_train_data))
}
mydat <- decisionplot(model = "Logistic",
data = data,
var1 = var1,
var2 = var2,
class = class)
mydat
###
ggplot() +
geom_point(aes(x = x, y = y, colour = class, size = prob_class),
data = mydat[[5]]) +
scale_size(range=c(0.2, 1)) +
geom_contour(aes(x = x, y = y, z = prob_class, group = factor(class), color = factor(class)),
bins = 2,
data = mydat[[5]]) +
geom_point(aes(x = X1, y = X2, color = factor(Y)),
size = 3,
alpha = 0.2,
data = mydat[[7]]) +
geom_point(aes(x = X1, y = X2),
size = 3,
shape = 1,
alpha = 0.2,
data = mydat[[7]])