R: как извлечь базовые прогнозы для Forex.xgboost? - PullRequest
0 голосов
/ 14 апреля 2020
library(xgboost)
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')

# Initialize baseline predictions to be 1.5
baseline_predictions <- rep(1.5, nrow(agaricus.train$data))

# base_margin is the base prediction Xgboost will boost from ;
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label, base_margin = baseline_predictions)
dtest <- xgb.DMatrix(agaricus.test$data)

param <- list(max_depth = 2, eta = 1, verbose = 0, nthread = 2,
              objective = "binary:logistic", eval_metric = "auc")

# Train model
bst <- xgb.train(param, dtrain, nrounds = 2)

#Predict on test set
predict(bst, newdata = dtest)

В приведенном выше коде я обучил модель xgboost с именем bst, которая была инициализирована с baseline_predictions. Затем я использовал функцию predict для подгонки модели к тестовому набору dtest.

Мой вопрос: как я могу определить, откуда была повышена модель в predict(bst, newdata = dtest)? Я понимаю, что могу использовать следующий код для извлечения базовых значений, из которых будет увеличено значение bst:

xgboost::getinfo(object = dtrain, name = "base_margin")

Но поскольку я не указал base_margin в xgb.DMatrix(agaricus.test$data), выполнение следующего кода возвращает NULL

xgboost::getinfo(object = dtest, name = "base_margin")

Так что predict(bst, newdata = dtest) усиливается с baseline_predictions (1,5 для всех наблюдений) или это повышение с чего-то еще?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...