Прогнозы случайных лесов дают одинаковое значение для каждого наблюдения в тестовом наборе, но разные значения для наблюдений в тренировочном наборе - PullRequest
1 голос
/ 14 мая 2019

Я пытался понять команду causal_forest в пакете grf в R. Когда я обучал модель с использованием следующего кода, я получал один и тот же прогноз для всех тестовых наблюдений, тогда как когда я делал прогнозы для обучающего набора, прогнозируемые значения были разными для каждогонаблюдение. Важность каждой переменной была показана равной 0. Я также проверил несколько случайных деревьев с помощью команды get_tree (), и они показали один лист, имеющий 12 наблюдений. Кроме того, прогнозируемые значения изменяются при увеличении n или мин.node.size установлен в ноль.

Мои вопросы: учитывая, что каждая переменная имеет значение 0, а деревья также имеют один лист (что я предполагаю, потому что проверяю множество случайных индексов и не могупроверьте все индексы) как для каждого наблюдения в обучающем наборе прогнозируемые значения различны, но все они одинаковы при прогнозировании с помощью тестового набора.Также есть ли способ визуализации того, как делаются прогнозы для обучения и тестирования.

library(grf)
#Train a causal forest.
n = 50; p = 10 
#n = 500
#genereate Normal random numbers with n rows and p columns 
X = matrix(rnorm(n*p), n, p)
#randomly assign treatment
W = rbinom(n, 1, 0.5)
#Simulated actual outcome 
Y = pmax(X[,1], 0) * W + X[,2] + pmin(X[,3], 0) + rnorm(n)
#run causal forest to predict Y by X and treatment W
c.forest = causal_forest(X, Y, W)
#c.forest = causal_forest(X, Y, W,min.node.size = 0)
c.forest
#plot tree at index 1
tree_1 <- grf::get_tree(c.forest,400)
plot(tree_1)
# Making test dataset
X.test = matrix(0, 101, p)
X.test[,1] = seq(-2, 2, length.out = 101)
# Predict using the forest.
c.pred = predict(c.forest,X.test)
c.pred
# Predict on out-of-bag training samples.
c.pred = predict(c.forest)
c.pred



> c.pred = predict(c.forest,X.test)
> head(c.pred,10)
    predictions
1   0.05819162
2   0.05819162
3   0.05819162
4   0.05819162
5   0.05819162
6   0.05819162
7   0.05819162
8   0.05819162
9   0.05819162
10  0.05819162

> c.pred = predict(c.forest)
> head(c.pred,10)
    predictions debiased.error
1  -0.110533258       7.859473
2  -0.048431382       7.402908
3   0.083399806       0.430714
4   0.007680637       1.557309
5  -0.010592712       3.331539
6   0.183042643       5.904475
7  -0.004142610       1.284445
8   0.092520160       1.734730
9  -0.036709332       2.479288
10 -0.019135867       1.230150
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...