Я пытаюсь реализовать версию R этой модели .
Задача состоит в том, чтобы найти решение:
Я использую optim
. Я сделал две версии кода. Один использует data.table
другой использует dplyr
. Они дают тот же результат, но время выполнения сильно варьируется (data.table
медленнее в 3 раза). Это заставило меня задуматься о том, что было бы наиболее эффективным решением проблемы.
Пример набора данных:
my_df <- tibble(group = c("a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "a", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c", "c"),
x = c(-7, -6, -5, -4, -3, -2, -1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 6, 7, 8, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, -15, -14, -13, -12, -11, -9, -8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 8),
y = c(8.744747, 9.404119, 10.27244, 10.55472, 8.620625, 9.407579, 8.971229, 9.658034, 9.705862, 9.25231, 9.751112, 10.77019, 10.37949, 10.50406, 9.700449, 8.715676, 9.357111, 7.672174, 8.107626, 8.517035, 7.305393, 7.241316, 5.763121, 8.454019, 8.020571, 10.29663, 5.762582, 9.417904, 8.921259, 9.549928, 7.744207, 9.493325, 8.178336, 8.170123, 9.129404, 10.40668, 5.865832, 5.828569, 5.864476, 5.796099, 5.978162, 9.335001, 8.831925, 9.387743, 8.85446, 8.624667, 9.190999, 9.151558, 9.012847, 8.575204, 7.646965, 7.476492, 5.73551, 7.353202, 5.78884, 6.867851, 5.223194, 6.747674, 5.188521, 5.537332, 5.372287, 5.375688))
(В моих реальных данных имеется около 4000 групп для примерно 40000 наблюдений.)
Для начальных параметров я просто подбираю каждую группу с lm
, чтобы получить $ a_i $, $ b_i $. Эти данные я использую в качестве данных для 2-го lm
, чтобы получить $ A $ и $ B $:
optim_init <- function(df_in) {
# get the a_i, b_i parameters
df_out <- df_in %>%
group_by(group) %>%
do(model = lm(y ~ x, data = .)) %>%
broom::tidy(model) %>%
ungroup %>%
dplyr::select(group, term, estimate) %>%
mutate(term =
case_when(
term == "(Intercept)" ~ "a",
term == "x" ~ "b")) %>%
pivot_wider(names_from = term, values_from = estimate)
# starting parameter fit for A, B
nlm_fit2 <- broom::tidy(lm(a ~ b, df_out)) %>%
mutate(term = ifelse(test = (term == "b"), yes = "B", no = "A"))
A <- nlm_fit2[(nlm_fit2$term == "A"),]$estimate
B <- nlm_fit2[(nlm_fit2$term == "B"),]$estimate
df_out <- df_out %>%
mutate(A = !!A,
B = !!B)
# and return the final df
return(df_out)
}
Я превращаю это в par
вектор для optim
путем:
optim_par <- function(df_in, term = a){
vec_out <- eval(substitute(term), df_in)
vec_out <- append(vec_out, unique(df_in$A))
vec_out <- append(vec_out, unique(df_in$B))
return(vec_out)
}
А фактическая функция для вычисления квадрата ошибки на основе data.frames
:
optim_fn_df <- function(par, df, df_init){
#matche the parameter vector to categories
a <- par[1 : nrow(df_init)]
A <- par[length(par)-1]
B <- par[length(par)]
df_init$a <- a
df_init$A <- A
df_init$B <- B
# calculate the error-squared
df %>%
left_join(df_init, by = c("group")) %>%
mutate(error = (a + (a-A)/B * x - y)^2) %>%
summarise(error = sum(error, na.rm = TRUE)) %>%
as.double()
}
Выполнение
my_init <- optim_init(my_df)
my_par <- optim_par(my_init)
tictoc::tic()
optim(par = my_par,
optim_fn_a,
df = my_df,
df_init = my_init,
method = "BFGS")
tictoc::toc()
заканчивается на моем компьютере примерно через 1,5 секунды.
Я попытался изменить свою функцию optim_fn
на data.tables
.
optim_fn_dt <- function(par, dt, dt_init){
a <- par[1 : nrow(dt_init)]
A <- par[length(par)-1]
B <- par[length(par)]
dt_init$a <- a
dt_init$A <- A
dt_init$B <- B
# calculate the error squared
dt <- merge(dt, dt_init)
dt[,error := (a + (a-A)/B * x - y)^2][,sum(error)]
}
, а запуск
dt_init <- data.table(df_init)
my_dt <- data.table(my_df)
optim(par = my_par,
optim_fn_dt,
dt = my_dt,
dt_init = my_init,
method = "BFGS")
занимает около 4,5 с.
Это довольно удивительно, поскольку на основе этого поста data.table::merge
примерно в 20 раз быстрее, чем dplyr::left_join
.
Так что же не так с моим optim_fn_dt
Как я могу сделать это быстрее? Я думал об использовании простых числовых матриц, но там я столкнулся с проблемой сопоставления параметров с соответствующими (x_i, y_i) парами из данных.