Применение функции объяснения DALEX к модели xgboost для анализа what_if / centerisParibus - PullRequest
0 голосов
/ 26 мая 2020

Мне трудно применить анализ what_if к модели xgboost. Я могу выполнить анализ what_if для модели randomForest, однако он ломается, когда я пытаюсь запустить его для модели xgboost.

Мой вопрос, учитывая набор данных titanic, как можно Я делаю сюжет what_if? Я добавил комментарии к коду, чтобы показать, когда код ломается для меня.

Я знаю, что делаю что-то неправильно с частью new_xgb_observation, но what_if (насколько мне известно) ожидает единственное наблюдение, поэтому я пытаюсь извлечь из матрицы dtest одно наблюдение.

Это часть кода, которая у меня ломается:

#### #### #### #### #### #### ####
# new observation -  which breaks
new_xgb_observation <- dtest[1, ]

# ceteris paribus - what_if analysis which breaks
what_if(xgb_explain, observation = new_xgb_observation,
        selected_variables = c("gender", "age", "fare", "sibsp"))
#### #### #### #### #### #### ####

Затем я показываю рабочая randomForest модель ниже него.

Данные:

library(DALEX)
library(ceterisParibus)
library(xgboost)

data("titanic")
data <- titanic

# some quick data cleaning
data <- data %>% 
  select(-c(class, embarked, country)) %>% 
  mutate(
    gender = as.numeric(gender) - 1,
    survived = as.numeric(survived) -1
  )

# split into training and testing data
smp_size <- floor(0.75 * nrow(data))
train_ind <- sample(seq_len(nrow(data)), size = smp_size)

train <- data[train_ind, ]
test <- data[-train_ind, ]

X_train <- train %>% 
  select(-c(survived)) %>% 
  as.matrix()

Y_train <- train %>% 
  select(c(survived)) %>% 
  as.matrix()

X_test <- test %>% 
  select(-c(survived)) %>% 
  as.matrix()

Y_test <- test %>% 
  select(c(survived)) %>% 
  as.matrix()

# train and test as xgb.DMatrix for the XGBoost model
dtrain <- xgb.DMatrix(data = X_train, label = Y_train)
dtest <- xgb.DMatrix(data = X_test, label = Y_test)

# XGBoost parameters
params <- list(
  "eta" = 0.2,
  "max_depth" = 6,
  "objective"="binary:logistic",
  "eval_metric"= "auc",
  "set.seed" = 176
)

# run the XGBoost model
watchlist <- list("train" = dtrain)
nround = 40
xgb.model <- xgb.train(params, dtrain, nround, watchlist)

# DALEX model explanation
xgb_explain <- explain(xgb.model, data = X_train, label = Y_train)

#### #### #### #### #### #### ####
# new observation -  which breaks
new_xgb_observation <- dtest[1, ]

# ceteris paribus - what_if analysis which breaks
what_if(xgb_explain, observation = new_xgb_observation,
        selected_variables = c("gender", "age", "fare", "sibsp"))
#### #### #### #### #### #### ####

################## random Forest model #################

Random_Forest_Model <- randomForest::randomForest(factor(survived) ~., data = train, na.action = na.omit, ntree = 50, importance = TRUE)

# same as for the XGBoost model but this time remove na values
train_rf <- na.omit(train)
X_train_rf <- train_rf %>% 
  select(-c(survived))
Y_train_rf <- train_rf %>% 
  select(c(survived))

test_rf <- na.omit(test)
X_test_rf <- test_rf %>% 
  select(-c(survived))
Y_test_rf <- test_rf %>% 
  select(c(survived))

# DALEX model explanation
rf_explain <- explain(Random_Forest_Model, 
                      data = X_train_rf,
                      y = Y_train_rf)

# This time this works.
new_obs <- X_test_rf[1, ]

# So does this
wi_rf_model <- what_if(rf_explain, observation = new_obs,
                       selected_variables = c("gender", "age", "fare", "sibsp"))

# And this is what I ultimately want.
plot(wi_rf_model, split = "variables", color = "variables", quantiles = FALSE)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...