извлечь содержимое из confusionMatrix, сохраненного в столбце списка в dplyr - PullRequest
0 голосов
/ 27 сентября 2018

Как показано в коде ниже, после перекрестной проверки я пытаюсь извлечь метрики модели для каждого сгиба.Я сохранил все прогнозы при повторной выборке, сгруппировал данные по сгибам, вычислил матрицу путаницы для каждой группы и сохранил объект матрицы путаницы в виде столбца списка cm.Теперь мне нужно извлечь метрическую информацию, например, точность и т. Д. Из объектов, сохраненных в столбце.Мой пример кода показан ниже.

library(caret)
iris2 = iris %>% 
    filter(Species != 'setosa') %>%
    mutate(Species = factor(Species))

train.control <- trainControl(method="cv", 
                           number=5,
                           summaryFunction = twoClassSummary,
                           classProbs = TRUE,
                           savePredictions='all')
rf = train(Species~., data=iris2,  method = 'rf',
           metric = 'ROC', trControl=train.control)
rf$pred %>% group_by(Resample) %>%
    do(cm = confusionMatrix(.$pred, .$obs),
       Accuracy = map(cm, ~.x$byClass['Precision'])) 

Я получил сообщение об ошибке:

Error in .x$byClass : $ operator is invalid for atomic vectors

Я не мог понять, почему это не работает.У меня вопрос, как я могу изменить последнюю строку, чтобы она работала?Спасибо

1 Ответ

0 голосов
/ 27 сентября 2018

Вы можете использовать ungroup(), а затем просто mutate Accuracy, получая доступ к определенной части list для каждой складки, которую вы используете unlist() для извлечения самого элемента.

rf$pred %>% 
  group_by(Resample) %>%
  do(cm = confusionMatrix(.$pred, .$obs)) %>% 
  ungroup() %>% 
  mutate(neg_pred_value = map(cm, ~ .x[["byClass"]][["Neg Pred Value"]]) %>% unlist(),
         accuracy = map(cm, ~ .x[["byClass"]][["Precision"]]) %>% unlist())

Используя приведенный выше код, мы получаем следующий вывод в виде tibble

# A tibble: 5 x 4
  Resample                    cm neg_pred_value  accuracy
     <chr>                <list>          <dbl>     <dbl>
1    Fold1 <S3: confusionMatrix>      0.9090909 1.0000000
2    Fold2 <S3: confusionMatrix>      1.0000000 1.0000000
3    Fold3 <S3: confusionMatrix>      1.0000000 1.0000000
4    Fold4 <S3: confusionMatrix>      0.8181818 0.8888889
5    Fold5 <S3: confusionMatrix>      1.0000000 0.9090909
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...