Поиск лямбда упругой сети H2O не выбирает лямбда, которая минимизирует отклонение проверки - PullRequest
2 голосов
/ 04 июня 2019

При перекрестной проверке гиперпараметра эластичной сети lambda с использованием опции lambda_search алгоритм может не выбрать значение lambda из указанной сетки, что минимизирует отклонение в проверочном образце. Это также происходит, когда мы устанавливаем early_stopping = FALSE, то есть когда можно ожидать, что H2O оценит все значения lambda в сетке.

Этот оператор может быть проверен путем перекрестной проверки лямбды, сначала используя lambda_search = TRUE в h2o.glm(), затем выполняя поиск в сетке по тем же значениям лямбды, используя h2o.grid(), и сравнивая результирующие гиперпараметры и значения отклонения проверки. Смотрите код R ниже.

Эта проблема тесно связана с тем, что указано здесь и упоминается здесь . Этот вопрос добавляет документацию о том, что значение перекрестной проверки lambda не обязательно должно быть тем, которое минимизирует отклонение проверки. То есть проблема может быть более серьезной, чем вычисление H2O, вплоть до наилучшей лямбды и последующего выхода, как указано в комментариях здесь . Проблема возникла при настройке одного образца проверки в Tweedie glm со ссылкой на журнал, я не уверен, насколько он специфичен для этого параметра.

На основании этих результатов я бы всегда использовал поиск по сетке для определения lambda. Это уместно? В качестве альтернативы, есть ли в h2o.glm() опция, которая решает проблему с lambda_search?

rm(list = ls())
library(h2o)
library(tweedie)
library(tidyverse)

# Configuration -----------------------------------------------------------
# DGP:
n = 1000
k = 10
phi = 1
const = 0
bet = seq(-1, 1, length.out = k)
power = 1.5

# algorithm
alpha = 0.5

# Generate some data ------------------------------------------------------
set.seed(42)

x = rnorm(n * k) %>% 
  matrix(nrow = n, dimnames = list(NULL, paste0("x", seq(1, k))))
mu = as.numeric(exp(const + x %*% bet))

dat = x %>% 
  as_tibble() %>% 
  mutate(mu = mu,
         y  = rtweedie(n, 
                       mu = mu,
                       phi = phi, 
                       power = power),
         id = row_number(),
         sample = case_when(
           id <= n / 2 ~ "train",
           TRUE ~ "valid"))

# Initialize H2O ----------------------------------------------------------
h2o.init()

df_h2o_train = dat %>% 
  filter(sample == "train") %>% 
  as.h2o()

df_h2o_valid = dat %>% 
  filter(sample == "valid") %>% 
  as.h2o()


# Tune lambda -------------------------------------------------------------
# 1. Lambda search
glm_warmstart = h2o.glm(
  x                      = paste0("x", seq(1, k)),
  y                      = "y",
  family                 = "tweedie",
  tweedie_variance_power = power,
  tweedie_link_power     = 0,
  training_frame         = df_h2o_train,
  validation_frame       = df_h2o_valid,
  alpha                  = alpha,
  lambda_search          = TRUE,
  early_stopping         = FALSE
)

lambda_warmstart = glm_warmstart@model$lambda_best 
print(lambda_warmstart) # 0.1501327

# 2. Grid search
hyper_params = list(lambda = glm_warmstart@model$scoring_history$lambda %>% 
                      h2o.asnumeric())

grid_search = h2o.grid("glm",
                       hyper_params           = hyper_params,
                       x                      = paste0("x", seq(1, k)),
                       y                      = "y",
                       family                 = "tweedie",
                       tweedie_variance_power = power,
                       tweedie_link_power     = 0,
                       training_frame         = df_h2o_train,
                       validation_frame       = df_h2o_valid,
                       alpha                  = alpha,
                       lambda_search          = FALSE)

lambda_grid_search = grid_search@summary_table %>% 
  as_tibble() %>%
  head(1) %>% 
  pull(lambda) %>% 
  stringr::str_sub(2, -2) %>% 
  as.numeric()
print(lambda_grid_search) # 0.013

glm_grid_search = h2o.glm(
  x                      = paste0("x", seq(1, k)),
  y                      = "y",
  family                 = "tweedie",
  tweedie_variance_power = power,
  tweedie_link_power     = 0,
  training_frame         = df_h2o_train,
  alpha                  = alpha,
  lambda                 = lambda_grid_search)

# Compare validation deviance ---------------------------------------------
dat %>% 
  filter(sample == "valid") %>% 
  mutate(pred_warmstart = as.vector(h2o.predict(glm_warmstart,
                                             newdata = df_h2o_valid)),
         pred_grid_search  = as.vector(h2o.predict(glm_grid_search,
                                             newdata = df_h2o_valid)),
         deviance_warmstart = tweedie.dev(y, pred_warmstart, power),
         deviance_grid_search = tweedie.dev(y, pred_grid_search, power)) %>% 
  summarise(
    mean_deviance_warmstart = mean(deviance_warmstart), # 1.16
    mean_deviance_grid_search = mean(deviance_grid_search) # 1.08
  )

# Close -------------------------------------------------------------------
h2o.shutdown(prompt = FALSE)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...