Что означает компонент err.rate класса randomForest? - PullRequest
2 голосов
/ 05 марта 2020

Я использую функцию randomForest из пакета randomForest. Одним из объектов класса randomForest является err.rate, который представляет собой

(только классификация) векторных ошибок при прогнозировании на входных данных, i-й элемент является частотой ошибок (OOB) для все деревья вплоть до i-го.

Не могли бы вы объяснить, что означает этот компонент? Большое спасибо за вашу помощь!

В качестве примера кода я использую набор данных Sonar, Mines vs. Rocks.

library(mlbench)
data(Sonar)
library(boot)
library(randomForest)

n <- 208
ntrain <- 100
ntest <- 108
train.idx <- sample(1:n, ntrain, replace = FALSE)
train.set <- Sonar[train.idx, ]
test.set <- Sonar[-train.idx, ]

rf <- randomForest(Class ~ ., data = train.set, keep.inbag = TRUE, importance = TRUE)
head(rf$err.rate)

Вот результат кода

             OOB         M         R
  [1,] 0.1891892 0.1500000 0.2352941
  [2,] 0.2931034 0.2307692 0.3437500
  [3,] 0.2739726 0.2647059 0.2820513
  [4,] 0.2911392 0.2894737 0.2926829
  [5,] 0.2413793 0.2682927 0.2173913
  [6,] 0.2555556 0.2142857 0.2916667
  [7,] 0.2553191 0.2444444 0.2653061
  [8,] 0.2268041 0.1956522 0.2549020
  [9,] 0.2783505 0.2608696 0.2941176

1 Ответ

1 голос
/ 07 марта 2020

Одним из компонентов randomForest является пакетирование, в котором вы получаете консенсус-прогноз по количеству деревьев.

При увеличении количества деревьев ошибка OOB вычисляется на каждом шаге. Ошибка OOB не рассчитывается из сравнения прогноза, полученного из 1 дерева, с выборками OOB относительно этого дерева, но вместо этого вы используете усредненный прогноз по деревьям, из которых эта выборка не используется. Я рекомендую проверить это для обзора .

Итак, в вашем примере мы можем визуализировать это:

library(ggplot2)
library(tidyr)

plotdf <- pivot_longer(data.frame(ntrees=1:nrow(rf$err.rate),rf$err.rate),-ntrees)
ggplot(plotdf,aes(x=ntrees,y=value,col=name)) + 
geom_line() + theme_bw()

enter image description here

M и R - строки для ошибки в прогнозе для этой указанной c метки, а OOB (ваш первый столбец) - просто среднее из двух. По мере увеличения числа деревьев ваша ошибка OOB становится меньше, потому что вы получаете лучший прогноз по большему количеству деревьев.

Хорошая особенность randomForest в том, что вам не нужна перекрестная проверка, поскольку оценка OOB обычно довольно показательно. Ниже мы можем попытаться показать, что получаем тот же результат:

set.seed(12)
# split in 5 parts
trn = split(1:nrow(Sonar),sample(1:nrow(Sonar) %% 5))
sim = vector("list",5)
# the number of trees we incrementally grow
ntrees = c(1,20*(1:50)+1)

for(CV in 1:5){
idx = trn[[CV]]
train.set <- Sonar[-idx, ]
test.set <- Sonar[idx, ]
# first forest, n=1, but works
mdl <- randomForest(Class ~ ., data = train.set, ntree=1,
keep.inbag = TRUE, importance = TRUE,keep.forest=TRUE)
err_rate <- vector("numeric",51)
err_rate[1] <- mean(predict(mdl,test.set)!=test.set$Class)
#growing the tree
for(i in 1:50){
  mdl <- grow(mdl,10)
  err_rate[i+1] <- mean(predict(mdl,test.set)!=test.set$Class)
}
sim[[CV]] <- data.frame(ntrees=ntrees,err_rate=err_rate,CV=CV)
}
sim = do.call(rbind,sim)

#plot

ggplot(sim,aes(x=ntrees,y=err_rate)) + geom_line(aes(group=CV),alpha=0.2) + 
stat_summary(fun.y=mean,geom="line",col="blue")+theme_bw()

enter image description here

...