Невозможно запустить настройку параметров для модели регрессии XGBoost с помощью каретки - PullRequest
0 голосов
/ 23 ноября 2018

Я пытаюсь построить регрессионную модель, используя данные Boston Housing, используя пакет caret.Код выглядит следующим образом:

library(tidyverse)
library(ggplot2)
library(lubridate)
library(broom)
library(caret)
library(xgboost)

#list.files()

options(scipen = 999)

library(MASS)

data_model <- Boston
data_model <- as.data.frame(data_model)

# based on this link https://stackoverflow.com/questions/51762536/r-xgboost-on-caret-attempts-to-perform-classification-instead-of-regression
data_model$medv <- as.double(data_model$medv)
data_model$zn <- as.double(data_model$zn)
xgb_grid_1 = expand.grid(
  nrounds = 1000,
  max_depth = c(2, 4, 6, 8, 10),
  eta=c(0.5, 0.1, 0.07),
  gamma = 0.01,
  colsample_bytree=0.5,
  min_child_weight=1,
  subsample=0.5
)

xgb_trcontrol_1 = trainControl(
  method = "cv",
  number = 5,
  allowParallel = TRUE
)


xgb_train_1 = train(
  x = data_model %>% dplyr::select(-medv) %>% as.matrix(),
  y = as.matrix(data_model$medv),
  trControl = xgb_trcontrol_1,
  tuneGrid = xgb_grid_1,
  method = "xgbTree",
  metric = 'RMSE'
)

sessionInfo()

Но когда я запускаю функцию train(), я получаю ошибку Error: Metric RMSE not applicable for classification models.Затем я попытался изменить переменные, которые были integers на double, как предложено этой ссылкой .Я все еще, кажется, получаю ту же ошибку.Я пропускаю дополнительный параметр, который должен позаботиться об этом?Заранее спасибо!Я также включил информацию о моем сеансе ниже на случай конфликта версий, о котором мне неизвестно.

R version 3.4.0 (2017-04-21)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS  10.14

Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.4/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] MASS_7.3-47        bindrcpp_0.2.2     xgboost_0.71.2     caret_6.0-81       lattice_0.20-35    broom_0.4.2        lubridate_1.6.0    dplyr_0.7.8        purrr_0.2.3       
[10] readr_1.1.1        tidyr_0.7.2        tibble_1.4.2       ggplot2_2.2.1.9000 tidyverse_1.1.1   

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.0           class_7.3-14         utf8_1.1.3           assertthat_0.2.0     ipred_0.9-6          psych_1.7.5          foreach_1.4.3        R6_2.2.2            
 [9] cellranger_1.1.0     plyr_1.8.4           stats4_3.4.0         httr_1.3.1           pillar_1.2.1         rlang_0.3.0.1        lazyeval_0.2.1       readxl_1.0.0        
[17] rstudioapi_0.7       data.table_1.10.4    rpart_4.1-11         Matrix_1.2-9         splines_3.4.0        gower_0.1.2          stringr_1.3.0        foreign_0.8-67      
[25] munsell_0.4.3        compiler_3.4.0       modelr_0.1.1         pkgconfig_2.0.1      mnormt_1.5-5         nnet_7.3-12          tidyselect_0.2.5     prodlim_2018.04.18  
[33] codetools_0.2-15     crayon_1.3.4         withr_2.1.2          recipes_0.1.4        ModelMetrics_1.1.0   grid_3.4.0           nlme_3.1-131         jsonlite_1.5        
[41] gtable_0.2.0         magrittr_1.5         waterfalls_0.1.2     scales_0.5.0.9000    cli_1.0.0            stringi_1.1.7        reshape2_1.4.3       timeDate_3012.100   
[49] xml2_1.2.0           generics_0.0.1       lava_1.6.1           iterators_1.0.8      tools_3.4.0          forcats_0.2.0        glue_1.3.0           hms_0.3             
[57] parallel_3.4.0       survival_2.41-3      colorspace_1.3-2     xgboostExplainer_0.1 rvest_0.3.2          bindr_0.1.1          haven_1.1.0  

1 Ответ

0 голосов
/ 23 ноября 2018

Вы уже конвертировали data_model$zn в double.Итак, просто удалите as.matrix в параметре y = as.matrix(data_model$medv)

...