nested_cv из rsamples - не уверены в результатах с более чем 2 группами - PullRequest
2 голосов
/ 18 октября 2019

Я запускаю следующее:

data(sunspots)
data <- sunspots %>% 
  as_tibble() %>% 
  mutate(category = sample(c('male', 'female'), nrow(.), replace=TRUE))

data

Что дает мне:

# A tibble: 2,820 x 2
       x category
   <dbl> <chr>   
 1  58   male    
 2  62.6 male    
 3  70   male    
 4  55.7 male    
 5  85   male    
 6  83.5 female  
 7  94.8 female  
 8  66.3 female  
 9  75.9 male    
10  75.5 female  
# ... with 2,810 more rows

Что является объектом временного ряда. То, что я хочу сделать, это group_by столбец male / female и применить rolling_origin из пакета rsample, где он составляет 100 учебных месяцев и 1 тестовый месяц (для каждой категории).

Я делаю следующее:

library(rsample)
periods_train <- 100
periods_test  <- 1
skip_span     <- 0


cv_rolling <- nested_cv(data, 
                        outside = group_vfold_cv(group = "category"),
                        inside = rolling_origin(
                          initial    = periods_train,
                          assess     = periods_test,
                          cumulative = FALSE,
                          skip       = skip_span))

Что дает мне:

[1] "nested_cv"      "group_vfold_cv" "rset"           "tbl_df"         "tbl"            "data.frame"    
# Nested resampling:
#  outer: Group -fold cross-validation
#  inner: Rolling origin forecast resampling
# A tibble: 2 x 3
  splits              id        inner_resamples     
  <named list>        <chr>     <named list>        
1 <split [1.4K/1.4K]> Resample1 <tibble [1,311 x 2]>
2 <split [1.4K/1.4K]> Resample2 <tibble [1,309 x 2]>

Я запускаю следующее:

> cv_rolling$inner_resamples
$`1`
# Rolling origin forecast resampling 
# A tibble: 1,311 x 2
   splits          id       
   <list>          <chr>    
 1 <split [100/1]> Slice0001
 2 <split [100/1]> Slice0002
 3 <split [100/1]> Slice0003
 4 <split [100/1]> Slice0004
 5 <split [100/1]> Slice0005
 6 <split [100/1]> Slice0006
 7 <split [100/1]> Slice0007
 8 <split [100/1]> Slice0008
 9 <split [100/1]> Slice0009
10 <split [100/1]> Slice0010
# ... with 1,301 more rows

$`2`
# Rolling origin forecast resampling 
# A tibble: 1,309 x 2
   splits          id       
   <list>          <chr>    
 1 <split [100/1]> Slice0001
 2 <split [100/1]> Slice0002
 3 <split [100/1]> Slice0003
 4 <split [100/1]> Slice0004
 5 <split [100/1]> Slice0005
 6 <split [100/1]> Slice0006
 7 <split [100/1]> Slice0007
 8 <split [100/1]> Slice0008
 9 <split [100/1]> Slice0009
10 <split [100/1]> Slice0010
# ... with 1,299 more rows

И все выглядит правильно, sunspots данные имеют 2820 наблюдений и расщепления 1311 + 1309 = 2620. (-200 из данных солнечных пятен, так как я установил periods_train <- 100 для обеих категорий.

Running:

map(cv_rolling$inner_resamples$`1`$splits, ~ analysis(.x)) %>% 
  head()

map(cv_rolling$inner_resamples$`1`$splits, ~ analysis(.x)) %>% 
  tail()

Кажется, чтобы дать мне правильный вывод также для "мужской" (замена 1 с 2 дает «женский» вывод.) В качестве случайной проверки я также запускаю следующее:

Первый код дает мне результат, второй нет (что и следовало ожидать, так как яЯ только анализирую мужской 1 сплит.)

cv_rolling$inner_resamples$`1`$splits[[650]] %>% 
  analysis() %>% 
  filter(category == "male")

cv_rolling$inner_resamples$`1`$splits[[650]] %>% 
  analysis() %>% 
  filter(category == "female")

Проблема

Я возвращаюсь к началу и добавляю в категорию other

data(sunspots)
data <- sunspots %>% 
  as_tibble() %>% 
  mutate(category = sample(c('male', 'female', 'other'), nrow(.), replace=TRUE))

Там, где я выполняю тот же анализ:

> data
# A tibble: 2,820 x 2
       x category
   <dbl> <chr>   
 1  58   other   
 2  62.6 male    
 3  70   female  
 4  55.7 male    
 5  85   male    
 6  83.5 male    
 7  94.8 other   
 8  66.3 other   
 9  75.9 female  
10  75.5 male    
# ... with 2,810 more rows

Я запускаю точно так же, как и раньше:

library(rsample)
periods_train <- 100
periods_test  <- 1
skip_span     <- 0


cv_rolling <- nested_cv(data, 
                        outside = group_vfold_cv(group = "category"),
                        inside = rolling_origin(
                          initial    = periods_train,
                          assess     = periods_test,
                          cumulative = FALSE,
                          skip       = skip_span))

cv_rolling

Теперь у меня есть 3 ожидаемых выхода.

[1] "nested_cv"      "group_vfold_cv" "rset"           "tbl_df"         "tbl"            "data.frame"    
# Nested resampling:
#  outer: Group -fold cross-validation
#  inner: Rolling origin forecast resampling
# A tibble: 3 x 3
  splits             id        inner_resamples     
  <named list>       <chr>     <named list>        
1 <split [1.9K/961]> Resample1 <tibble [1,759 x 2]>
2 <split [1.9K/939]> Resample2 <tibble [1,781 x 2]>
3 <split [1.9K/920]> Resample3 <tibble [1,800 x 2]>

Запуск:

> cv_rolling$inner_resamples
$`1`
# Rolling origin forecast resampling 
# A tibble: 1,759 x 2
   splits          id       
   <list>          <chr>    
 1 <split [100/1]> Slice0001
 2 <split [100/1]> Slice0002
 3 <split [100/1]> Slice0003
 4 <split [100/1]> Slice0004
 5 <split [100/1]> Slice0005
 6 <split [100/1]> Slice0006
 7 <split [100/1]> Slice0007
 8 <split [100/1]> Slice0008
 9 <split [100/1]> Slice0009
10 <split [100/1]> Slice0010
# ... with 1,749 more rows

$`2`
# Rolling origin forecast resampling 
# A tibble: 1,781 x 2
   splits          id       
   <list>          <chr>    
 1 <split [100/1]> Slice0001
 2 <split [100/1]> Slice0002
 3 <split [100/1]> Slice0003
 4 <split [100/1]> Slice0004
 5 <split [100/1]> Slice0005
 6 <split [100/1]> Slice0006
 7 <split [100/1]> Slice0007
 8 <split [100/1]> Slice0008
 9 <split [100/1]> Slice0009
10 <split [100/1]> Slice0010
# ... with 1,771 more rows

$`3`
# Rolling origin forecast resampling 
# A tibble: 1,800 x 2
   splits          id       
   <list>          <chr>    
 1 <split [100/1]> Slice0001
 2 <split [100/1]> Slice0002
 3 <split [100/1]> Slice0003
 4 <split [100/1]> Slice0004
 5 <split [100/1]> Slice0005
 6 <split [100/1]> Slice0006
 7 <split [100/1]> Slice0007
 8 <split [100/1]> Slice0008
 9 <split [100/1]> Slice0009
10 <split [100/1]> Slice0010
# ... with 1,790 more rows

Который не имеет смысла ... Поскольку 1759 + 1781 + 1800 = 5340 и данные солнечных пятен имеют длину 2820.

Теперь у меня естьперекрывающиеся other и male наблюдения:

> map(cv_rolling$inner_resamples$`2`$splits, ~ analysis(.x)) %>% 
+   head()
[[1]]
# A tibble: 100 x 2
       x category
   <dbl> <chr>   
 1  58   other   
 2  62.6 male    
 3  55.7 male    
 4  85   male    
 5  83.5 male    
 6  94.8 other   
 7  66.3 other   
 8  75.5 male    
 9  85.2 male    
10  73.3 other   
# ... with 90 more rows

Наконец:

cv_rolling$inner_resamples$`1`$splits[[650]] %>% 
  analysis() %>% 
  filter(category == "male")

cv_rolling$inner_resamples$`1`$splits[[650]] %>% 
  analysis() %>% 
  filter(category == "other")

cv_rolling$inner_resamples$`1`$splits[[650]] %>% 
  analysis() %>% 
  filter(category == "female")

, который возвращает результаты для other и female, но не male.

Что я сделалздесь не так? Что если я хочу увеличить количество категорий, как я могу выполнить rolling_origin для каждой из категорий?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...