Какая структура наиболее эффективна для оптимальной работы со многими категориями данных в R? - PullRequest
0 голосов
/ 01 ноября 2019

Я пытаюсь реализовать версию R этой модели .

Задача состоит в том, чтобы найти решение:

eq1

Я использую optim. Я сделал две версии кода. Один использует data.table другой использует dplyr. Они дают тот же результат, но время выполнения сильно варьируется (data.table медленнее в 3 раза). Это заставило меня задуматься о том, что было бы наиболее эффективным решением проблемы.

Пример набора данных:

my_df <- tibble(group = c("a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c"),
                x = c(-7, -6, -5, -4, -3, -2, -1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 6, 7, 8, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, -15, -14, -13, -12, -11, -9, -8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 8),
                y = c(8.744747, 9.404119, 10.27244, 10.55472, 8.620625, 9.407579, 8.971229, 9.658034, 9.705862, 9.25231, 9.751112, 10.77019, 10.37949, 10.50406, 9.700449, 8.715676, 9.357111, 7.672174, 8.107626, 8.517035, 7.305393, 7.241316, 5.763121, 8.454019, 8.020571, 10.29663, 5.762582, 9.417904, 8.921259, 9.549928, 7.744207, 9.493325, 8.178336, 8.170123, 9.129404, 10.40668, 5.865832, 5.828569, 5.864476, 5.796099, 5.978162, 9.335001, 8.831925, 9.387743, 8.85446, 8.624667, 9.190999, 9.151558, 9.012847, 8.575204, 7.646965, 7.476492, 5.73551, 7.353202, 5.78884, 6.867851, 5.223194, 6.747674, 5.188521, 5.537332, 5.372287, 5.375688))

(В моих реальных данных имеется около 4000 групп для примерно 40000 наблюдений.)

Для начальных параметров я просто подбираю каждую группу с lm, чтобы получить $ a_i $, $ b_i $. Эти данные я использую в качестве данных для 2-го lm, чтобы получить $ A $ и $ B $:

optim_init <- function(df_in) {

  # get the a_i, b_i parameters
  df_out <- df_in %>% 
    group_by(group) %>% 
    do(model = lm(y ~ x, data = .)) %>% 
    broom::tidy(model) %>%
    ungroup %>%
    dplyr::select(group, term, estimate) %>%
    mutate(term =
             case_when(
               term == "(Intercept)" ~ "a",
               term == "x" ~ "b")) %>%
    pivot_wider(names_from = term, values_from = estimate)

  # starting parameter fit for A, B
  nlm_fit2 <- broom::tidy(lm(a ~ b, df_out)) %>%
    mutate(term = ifelse(test = (term == "b"), yes = "B", no = "A"))
  A <- nlm_fit2[(nlm_fit2$term == "A"),]$estimate
  B <- nlm_fit2[(nlm_fit2$term == "B"),]$estimate

  df_out <- df_out %>%
    mutate(A = !!A,
           B = !!B)
  # and return the final df
  return(df_out)
}

Я превращаю это в par вектор для optim путем:

optim_par <- function(df_in, term = a){
  vec_out <- eval(substitute(term), df_in)
  vec_out <- append(vec_out, unique(df_in$A))
  vec_out <- append(vec_out, unique(df_in$B))
  return(vec_out)
}

А фактическая функция для вычисления квадрата ошибки на основе data.frames:

optim_fn_df <- function(par, df, df_init){
  #matche the parameter vector to categories
  a <- par[1 : nrow(df_init)]
  A <- par[length(par)-1]
  B <- par[length(par)]
  df_init$a <- a
  df_init$A <- A
  df_init$B <- B

  # calculate the error-squared
    df %>%
    left_join(df_init, by = c("group")) %>%
    mutate(error = (a + (a-A)/B * x - y)^2) %>%
    summarise(error = sum(error, na.rm = TRUE)) %>%
    as.double()
}

Выполнение

my_init <- optim_init(my_df)
my_par <- optim_par(my_init)

tictoc::tic()
optim(par = my_par, 
      optim_fn_a, 
      df = my_df,
      df_init = my_init,
      method = "BFGS")
tictoc::toc()

заканчивается на моем компьютере примерно через 1,5 секунды.

Я попытался изменить свою функцию optim_fn на data.tables.

optim_fn_dt <- function(par, dt, dt_init){
  a <- par[1 : nrow(dt_init)]
  A <- par[length(par)-1]
  B <- par[length(par)]
  dt_init$a <- a
  dt_init$A <- A
  dt_init$B <- B
  # calculate the error squared
  dt <- merge(dt, dt_init) 
  dt[,error := (a + (a-A)/B * x - y)^2][,sum(error)]
}

, а запуск

dt_init <- data.table(df_init)
my_dt <- data.table(my_df)
optim(par = my_par, 
      optim_fn_dt, 
      dt = my_dt,
      dt_init = my_init,
      method = "BFGS")

занимает около 4,5 с.

Это довольно удивительно, поскольку на основе этого поста data.table::merge примерно в 20 раз быстрее, чем dplyr::left_join.

Так что же не так с моим optim_fn_dtКак я могу сделать это быстрее? Я думал об использовании простых числовых матриц, но там я столкнулся с проблемой сопоставления параметров с соответствующими (x_i, y_i) парами из данных.

...