R, iml, mlr. Важность признаков (FeatureImp) всегда равна 1 для каждого признака - PullRequest
1 голос
/ 23 января 2020

Я что-то делаю не так с фреймворком mlr, из-за чего FeatureImp возвращает 1 для каждого признака, и я не могу понять, в чём дело. Вот пример:

# Load packages: caret (used below for createDataPartition), mlr (tasks,
# learners, training, stacking) and iml (Predictor and FeatureImp).
library(caret)
#> Carregando pacotes exigidos: lattice
#> Carregando pacotes exigidos: ggplot2
# mlr is attached after caret, so mlr's train() masks caret::train (see the
# captured message below); unqualified train() calls later use mlr's version.
library(mlr)
#> Carregando pacotes exigidos: ParamHelpers
#> 
#> Attaching package: 'mlr'
#> The following object is masked from 'package:caret':
#> 
#>     train
library(iml)

# Turn iris into a binary problem: drop setosa rows, then encode
# virginica as 1 and versicolor as 0, stored as a factor target.
data("iris")
iris <- subset(iris, Species != "setosa")
iris$Species <- factor(ifelse(iris$Species == "virginica", 1, 0))

# Split rows roughly 80/20 into train/test with caret::createDataPartition
# on the target; the temporary index is removed afterwards.
split_rows <- createDataPartition(
  iris$Species,
  times = 1,
  p = 0.8,
  list = FALSE
)
train <- iris[split_rows, ]
test <- iris[-split_rows, ]
rm(split_rows)

# Binary classification tasks for the train/test splits; level "1"
# (virginica after recoding) is the positive class.
train.task <- makeClassifTask(data = train, target = "Species", positive = 1)
test.task <- makeClassifTask(data = test, target = "Species", positive = 1)

# Named list of base learners, each configured to predict probabilities.
learner <- lapply(
  c(
    xgboost = "classif.xgboost",
    ksvm = "classif.ksvm",
    nnet = "classif.nnet",
    randomForest = "classif.randomForest"
  ),
  makeLearner,
  predict.type = "prob"
)

# Fit each base learner on the training task (mlr's train(), which masks
# caret's); the captured trace below is presumably printed by nnet.
model <- lapply(learner, function(lrn) train(lrn, task = train.task))
#> # weights:  19
#> initial  value 57.506055 
#> iter  10 value 52.109027
#> iter  20 value 7.798098
#> iter  30 value 5.401193
#> iter  40 value 4.707935
#> iter  50 value 4.702049
#> final  value 4.701710 
#> converged
# Score every fitted base model on the held-out test task.
prediction <- lapply(model, function(fit) predict(fit, task = test.task))

# Stack the base learners under a randomForest super learner.
# method = "stack.cv" / use.feat = FALSE: per the mlr docs, the super learner
# is fit on cross-validated base-learner predictions without the raw features.
ensemble <- makeStackedLearner(
  base.learners = learner,
  super.learner = "classif.randomForest",
  predict.type = "prob",
  method = "stack.cv",
  use.feat = FALSE
)
model$ensemble <- train(ensemble, task = train.task)
#> # weights:  19
#> initial  value 43.712841 
#> iter  10 value 5.444287
#> iter  20 value 4.536990
#> iter  30 value 4.527489
#> iter  40 value 4.481401
#> iter  50 value 4.481221
#> iter  50 value 4.481221
#> iter  50 value 4.481221
#> final  value 4.481221 
#> converged
#> # weights:  19
#> initial  value 52.864011 
#> iter  10 value 33.347827
#> iter  20 value 2.926847
#> iter  30 value 0.011104
#> final  value 0.000055 
#> converged
#> # weights:  19
#> initial  value 44.627604 
#> iter  10 value 31.360597
#> iter  20 value 5.798769
#> iter  30 value 4.290623
#> iter  40 value 3.751202
#> iter  50 value 3.547856
#> iter  60 value 3.469366
#> iter  70 value 3.373487
#> iter  80 value 3.317680
#> iter  90 value 3.310354
#> iter 100 value 3.301115
#> final  value 3.301115 
#> stopped after 100 iterations
#> # weights:  19
#> initial  value 46.410266 
#> iter  10 value 29.975896
#> iter  20 value 1.266423
#> iter  30 value 0.004667
#> final  value 0.000052 
#> converged
#> # weights:  19
#> initial  value 52.665930 
#> final  value 44.361399 
#> converged
#> # weights:  19
#> initial  value 60.471973 
#> iter  10 value 50.475349
#> iter  20 value 7.580138
#> iter  30 value 4.828646
#> iter  40 value 4.543112
#> iter  50 value 2.995374
#> iter  60 value 2.636710
#> iter  70 value 2.539857
#> iter  80 value 2.497281
#> iter  90 value 2.427158
#> iter 100 value 2.370383
#> final  value 2.370383 
#> stopped after 100 iterations
# Predict the held-out test task with the fitted ensemble.
prediction$ensemble <- predict(model$ensemble, task = test.task)

# Wrap the fitted stacked ensemble for iml. Features and target come from
# mlr's public getTaskData() accessor instead of the task's internal
# `env$data` field, which is an implementation detail of mlr tasks.
train.parts <- getTaskData(train.task, target.extra = TRUE)
predictor <- Predictor$new(model$ensemble,
  data = train.parts$data,
  # Convert the "0"/"1" labels by their text rather than by factor level
  # position, so the numeric target does not depend on level order.
  y = as.numeric(as.character(train.parts$target)))

# Permutation feature importance using classification error ("ce") as loss.
# NOTE(review): with the CRAN release of iml this printed importance 1 for
# every feature; a GitHub development build produced differentiated
# values -- check the installed iml version.
imp <- FeatureImp$new(predictor, loss = "ce")
imp$results
#>        feature importance.05 importance importance.95 permutation.error
#> 1 Sepal.Length             1          1             1                 1
#> 2  Sepal.Width             1          1             1                 1
#> 3 Petal.Length             1          1             1                 1
#> 4  Petal.Width             1          1             1                 1

Создано 2020-01-23 с помощью пакета reprex (v0.3.0)

1 Ответ

1 голос
/ 23 января 2020

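Похоже, это исправлено в разрабатываемой (dev) версии {iml}.

Я могу воспроизвести вашу проблему с текущей версией {iml} из CRAN. Приведённый ниже вывод получен с dev-версией {iml}, установленной с GitHub (см. информацию о сеансе).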

# Load packages: caret (used below for createDataPartition), mlr (tasks,
# learners, training, stacking) and iml (Predictor and FeatureImp).
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
# mlr is attached after caret, so mlr's train() masks caret::train (see the
# captured message below); unqualified train() calls later use mlr's version.
library(mlr)
#> Loading required package: ParamHelpers
#> 'mlr' is in maintenance mode since July 2019. Future development
#> efforts will go into its successor 'mlr3' (<https://mlr3.mlr-org.com>).
#> 
#> Attaching package: 'mlr'
#> The following object is masked from 'package:caret':
#> 
#>     train
library(iml)

# Turn iris into a binary problem: drop setosa rows, then encode
# virginica as 1 and versicolor as 0, stored as a factor target.
data("iris")
iris <- subset(iris, Species != "setosa")
iris$Species <- factor(ifelse(iris$Species == "virginica", 1, 0))

# Split rows roughly 80/20 into train/test with caret::createDataPartition
# on the target; the temporary index is removed afterwards.
split_rows <- createDataPartition(
  iris$Species,
  times = 1,
  p = 0.8,
  list = FALSE
)
train <- iris[split_rows, ]
test <- iris[-split_rows, ]
rm(split_rows)

# Binary classification tasks for the train/test splits; level "1"
# (virginica after recoding) is the positive class.
train.task <- makeClassifTask(data = train, target = "Species", positive = 1)
test.task <- makeClassifTask(data = test, target = "Species", positive = 1)

# Named list of base learners, each configured to predict probabilities.
learner <- lapply(
  c(
    xgboost = "classif.xgboost",
    ksvm = "classif.ksvm",
    nnet = "classif.nnet",
    randomForest = "classif.randomForest"
  ),
  makeLearner,
  predict.type = "prob"
)

# Fit each base learner on the training task (mlr's train(), which masks
# caret's); the captured trace below is presumably printed by nnet.
model <- lapply(learner, function(lrn) train(lrn, task = train.task))
#> # weights:  19
#> initial  value 59.040647 
#> iter  10 value 54.908003
#> iter  20 value 8.784817
#> iter  30 value 2.906017
#> iter  40 value 0.187334
#> iter  50 value 0.000610
#> final  value 0.000059 
#> converged
# Score every fitted base model on the held-out test task.
prediction <- lapply(model, function(fit) predict(fit, task = test.task))

# Stack the base learners under a randomForest super learner.
# method = "stack.cv" / use.feat = FALSE: per the mlr docs, the super learner
# is fit on cross-validated base-learner predictions without the raw features.
ensemble <- makeStackedLearner(
  base.learners = learner,
  super.learner = "classif.randomForest",
  predict.type = "prob",
  method = "stack.cv",
  use.feat = FALSE
)
model$ensemble <- train(ensemble, task = train.task)
#> # weights:  19
#> initial  value 44.537254 
#> iter  10 value 6.716784
#> iter  20 value 4.750452
#> iter  30 value 4.487501
#> iter  40 value 4.481250
#> final  value 4.481222 
#> converged
#> # weights:  19
#> initial  value 54.135701 
#> iter  10 value 13.081961
#> iter  20 value 1.676063
#> iter  30 value 0.002261
#> final  value 0.000044 
#> converged
#> # weights:  19
#> initial  value 42.621635 
#> iter  10 value 5.201573
#> iter  20 value 2.878946
#> iter  30 value 1.133911
#> iter  40 value 0.002784
#> iter  50 value 0.000726
#> final  value 0.000037 
#> converged
#> # weights:  19
#> initial  value 43.795663 
#> iter  10 value 4.478310
#> iter  20 value 1.811306
#> iter  30 value 0.027775
#> iter  40 value 0.004873
#> iter  50 value 0.001480
#> iter  60 value 0.000230
#> iter  70 value 0.000221
#> final  value 0.000089 
#> converged
#> # weights:  19
#> initial  value 44.433321 
#> iter  10 value 7.252874
#> iter  20 value 1.200457
#> iter  30 value 0.001668
#> final  value 0.000063 
#> converged
#> # weights:  19
#> initial  value 67.012204 
#> final  value 55.451774 
#> converged
# Predict the held-out test task with the fitted ensemble.
prediction$ensemble <- predict(model$ensemble, task = test.task)

# Wrap the fitted stacked ensemble for iml. Features and target come from
# mlr's public getTaskData() accessor instead of the task's internal
# `env$data` field, which is an implementation detail of mlr tasks.
train.parts <- getTaskData(train.task, target.extra = TRUE)
predictor <- Predictor$new(model$ensemble,
  data = train.parts$data,
  # Convert the "0"/"1" labels by their text rather than by factor level
  # position, so the numeric target does not depend on level order.
  y = as.numeric(as.character(train.parts$target)))

# Permutation feature importance using classification error ("ce") as loss.
# NOTE(review): results depend on the iml version; the CRAN release of the
# time was reported to return importance 1 for every feature.
imp <- FeatureImp$new(predictor, loss = "ce")
imp$results
#>        feature importance.05 importance importance.95 permutation.error
#> 1  Petal.Width          11.1       12.0          14.2            0.3000
#> 2 Petal.Length          10.3       11.5          13.1            0.2875
#> 3 Sepal.Length           3.3        4.5           6.3            0.1125
#> 4  Sepal.Width           2.1        3.5           4.0            0.0875

Создано 2020-01-23 с помощью пакета reprex (v0.3.0)

Информация о сеансе

# Print R, OS and package versions so the environment can be reproduced.
# In the captured output below, iml is a GitHub build (christophM/iml@54b2ce2).
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                                      
#>  version  R version 3.6.2 Patched (2019-12-12 r77564)
#>  os       macOS Mojave 10.14.6                       
#>  system   x86_64, darwin15.6.0                       
#>  ui       X11                                        
#>  language (EN)                                       
#>  collate  en_US.UTF-8                                
#>  ctype    en_US.UTF-8                                
#>  tz       Europe/Berlin                              
#>  date     2020-01-23                                 
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version     date       lib
#>  assertthat     0.2.1       2019-03-21 [1]
#>  backports      1.1.5       2019-10-02 [1]
#>  BBmisc         1.11        2017-03-10 [1]
#>  callr          3.4.0       2019-12-09 [1]
#>  caret        * 6.0-85      2020-01-07 [1]
#>  checkmate      1.9.4       2019-07-04 [1]
#>  class          7.3-15      2019-01-01 [2]
#>  cli            2.0.1.9000  2020-01-12 [1]
#>  codetools      0.2-16      2018-12-24 [2]
#>  colorspace     1.4-1       2019-03-18 [1]
#>  crayon         1.3.4       2017-09-16 [1]
#>  data.table     1.12.8      2019-12-09 [1]
#>  desc           1.2.0       2018-05-01 [1]
#>  devtools       2.2.1       2019-09-24 [1]
#>  digest         0.6.23      2019-11-23 [1]
#>  dplyr          0.8.3       2019-07-04 [1]
#>  ellipsis       0.3.0       2019-09-20 [1]
#>  evaluate       0.14        2019-05-28 [1]
#>  fansi          0.4.1       2020-01-08 [1]
#>  fastmatch      1.1-0       2017-01-28 [1]
#>  foreach        1.4.7       2019-07-27 [1]
#>  fs             1.3.1       2019-05-06 [1]
#>  generics       0.0.2       2018-11-29 [1]
#>  ggplot2      * 3.2.1       2019-08-10 [1]
#>  glue           1.3.1       2019-03-12 [1]
#>  gower          0.2.1       2019-05-14 [1]
#>  gridExtra      2.3         2017-09-09 [1]
#>  gtable         0.3.0       2019-03-25 [1]
#>  highr          0.8         2019-03-20 [1]
#>  htmltools      0.4.0       2019-10-04 [1]
#>  iml          * 0.9.0       2020-01-23 [1]
#>  ipred          0.9-9       2019-04-28 [1]
#>  iterators      1.0.12      2019-07-26 [1]
#>  kernlab        0.9-29      2019-11-12 [1]
#>  knitr          1.27        2020-01-16 [1]
#>  lattice      * 0.20-38     2018-11-04 [2]
#>  lava           1.6.6       2019-08-01 [1]
#>  lazyeval       0.2.2       2019-03-15 [1]
#>  lifecycle      0.1.0       2019-08-01 [1]
#>  lubridate      1.7.4       2018-04-11 [1]
#>  magrittr       1.5         2014-11-22 [1]
#>  MASS           7.3-51.4    2019-03-31 [1]
#>  Matrix         1.2-18      2019-11-27 [2]
#>  memoise        1.1.0       2017-04-21 [1]
#>  Metrics        0.1.4       2018-07-09 [1]
#>  mlr          * 2.17.0.9000 2020-01-13 [1]
#>  ModelMetrics   1.2.2.1     2020-01-13 [1]
#>  munsell        0.5.0       2018-06-12 [1]
#>  nlme           3.1-143     2019-12-10 [2]
#>  nnet           7.3-12      2016-02-02 [2]
#>  parallelMap    1.4         2019-05-17 [1]
#>  ParamHelpers * 1.13.0.9000 2019-12-11 [1]
#>  pillar         1.4.3       2019-12-20 [1]
#>  pkgbuild       1.0.6       2019-10-09 [1]
#>  pkgconfig      2.0.3       2019-09-22 [1]
#>  pkgload        1.0.2       2018-10-29 [1]
#>  plyr           1.8.5       2019-12-10 [1]
#>  prediction     0.3.14      2019-06-17 [1]
#>  prettyunits    1.1.0       2020-01-09 [1]
#>  pROC           1.16.1      2020-01-14 [1]
#>  processx       3.4.1       2019-07-18 [1]
#>  prodlim        2019.11.13  2019-11-17 [1]
#>  ps             1.3.0       2018-12-21 [1]
#>  purrr          0.3.3       2019-10-18 [1]
#>  R6             2.4.1       2019-11-12 [1]
#>  randomForest   4.6-14      2018-03-25 [1]
#>  Rcpp           1.0.3       2019-11-08 [1]
#>  recipes        0.1.9       2020-01-07 [1]
#>  remotes        2.1.0       2019-06-24 [1]
#>  reshape2       1.4.3       2017-12-11 [1]
#>  rlang          0.4.3       2020-01-22 [1]
#>  rmarkdown      2.1         2020-01-20 [1]
#>  rpart          4.1-15      2019-04-12 [1]
#>  rprojroot      1.3-2       2018-01-03 [1]
#>  scales         1.1.0       2019-11-18 [1]
#>  sessioninfo    1.1.1       2018-11-05 [1]
#>  stringi        1.4.5       2020-01-11 [1]
#>  stringr        1.4.0       2019-02-10 [1]
#>  survival       3.1-8       2019-12-03 [2]
#>  testthat       2.3.1       2019-12-01 [1]
#>  tibble         2.1.3       2019-06-06 [1]
#>  tidyselect     0.2.5       2018-10-11 [1]
#>  timeDate       3043.102    2018-02-21 [1]
#>  usethis        1.5.1.9000  2020-01-17 [1]
#>  withr          2.1.2       2018-03-15 [1]
#>  xfun           0.12        2020-01-13 [1]
#>  xgboost        0.90.0.2    2019-08-01 [1]
#>  XML            3.99-0.3    2020-01-20 [1]
#>  yaml           2.2.0       2018-07-25 [1]
#>  source                                   
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  Github (r-lib/cli@f786d87)               
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.0)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  Github (christophM/iml@54b2ce2)          
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.0)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.0)                           
#>  local                                    
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.2)                           
#>  Github (berndbischl/ParamHelpers@c2d989c)
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.0)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.0)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  Github (r-lib/rlang@624c5c3)             
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  Github (pat-s/usethis@0251102)           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.0)                           
#> 
#> [1] /Users/pjs/Library/R/3.6/library
#> [2] /Library/Frameworks/R.framework/Versions/3.6/Resources/library

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...