При перекрестной проверке гиперпараметра эластичной сети lambda
с использованием опции lambda_search
алгоритм может не выбрать значение lambda
из указанной сетки, что минимизирует отклонение в проверочном образце. Это также происходит, когда мы устанавливаем early_stopping = FALSE
, то есть когда можно ожидать, что H2O оценит все значения lambda
в сетке.
Этот оператор может быть проверен путем перекрестной проверки лямбды, сначала используя lambda_search = TRUE
в h2o.glm()
, затем выполняя поиск в сетке по тем же значениям лямбды, используя h2o.grid()
, и сравнивая результирующие гиперпараметры и значения отклонения проверки. Смотрите код R ниже.
Эта проблема тесно связана с тем, что указано здесь и упоминается здесь . Этот вопрос добавляет документацию о том, что значение перекрестной проверки lambda
не обязательно должно быть тем, которое минимизирует отклонение проверки. То есть проблема может быть более серьезной, чем вычисление H2O, вплоть до наилучшей лямбды и последующего выхода, как указано в комментариях здесь . Проблема возникла при настройке одного образца проверки в Tweedie glm со ссылкой на журнал, я не уверен, насколько он специфичен для этого параметра.
На основании этих результатов я бы всегда использовал поиск по сетке для определения lambda
. Это уместно? В качестве альтернативы, есть ли в h2o.glm()
опция, которая решает проблему с lambda_search
?
rm(list = ls())
library(h2o)
library(tweedie)
library(tidyverse)
# Configuration -----------------------------------------------------------
# DGP:
n = 1000
k = 10
phi = 1
const = 0
bet = seq(-1, 1, length.out = k)
power = 1.5
# algorithm
alpha = 0.5
# Generate some data ------------------------------------------------------
set.seed(42)
x = rnorm(n * k) %>%
matrix(nrow = n, dimnames = list(NULL, paste0("x", seq(1, k))))
mu = as.numeric(exp(const + x %*% bet))
dat = x %>%
as_tibble() %>%
mutate(mu = mu,
y = rtweedie(n,
mu = mu,
phi = phi,
power = power),
id = row_number(),
sample = case_when(
id <= n / 2 ~ "train",
TRUE ~ "valid"))
# Initialize H2O ----------------------------------------------------------
h2o.init()
df_h2o_train = dat %>%
filter(sample == "train") %>%
as.h2o()
df_h2o_valid = dat %>%
filter(sample == "valid") %>%
as.h2o()
# Tune lambda -------------------------------------------------------------
# 1. Lambda search
glm_warmstart = h2o.glm(
x = paste0("x", seq(1, k)),
y = "y",
family = "tweedie",
tweedie_variance_power = power,
tweedie_link_power = 0,
training_frame = df_h2o_train,
validation_frame = df_h2o_valid,
alpha = alpha,
lambda_search = TRUE,
early_stopping = FALSE
)
lambda_warmstart = glm_warmstart@model$lambda_best
print(lambda_warmstart) # 0.1501327
# 2. Grid search
hyper_params = list(lambda = glm_warmstart@model$scoring_history$lambda %>%
h2o.asnumeric())
grid_search = h2o.grid("glm",
hyper_params = hyper_params,
x = paste0("x", seq(1, k)),
y = "y",
family = "tweedie",
tweedie_variance_power = power,
tweedie_link_power = 0,
training_frame = df_h2o_train,
validation_frame = df_h2o_valid,
alpha = alpha,
lambda_search = FALSE)
lambda_grid_search = grid_search@summary_table %>%
as_tibble() %>%
head(1) %>%
pull(lambda) %>%
stringr::str_sub(2, -2) %>%
as.numeric()
print(lambda_grid_search) # 0.013
glm_grid_search = h2o.glm(
x = paste0("x", seq(1, k)),
y = "y",
family = "tweedie",
tweedie_variance_power = power,
tweedie_link_power = 0,
training_frame = df_h2o_train,
alpha = alpha,
lambda = lambda_grid_search)
# Compare validation deviance ---------------------------------------------
dat %>%
filter(sample == "valid") %>%
mutate(pred_warmstart = as.vector(h2o.predict(glm_warmstart,
newdata = df_h2o_valid)),
pred_grid_search = as.vector(h2o.predict(glm_grid_search,
newdata = df_h2o_valid)),
deviance_warmstart = tweedie.dev(y, pred_warmstart, power),
deviance_grid_search = tweedie.dev(y, pred_grid_search, power)) %>%
summarise(
mean_deviance_warmstart = mean(deviance_warmstart), # 1.16
mean_deviance_grid_search = mean(deviance_grid_search) # 1.08
)
# Close -------------------------------------------------------------------
h2o.shutdown(prompt = FALSE)