tl; dr - таблица данных, использующая неравные объединения, может решить ее за то же время, в течение которого tidyr
завершила генерацию данных.Тем не менее, решение tidyr
/ dplyr
выглядит лучше.
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6
)))
] / (length(c0) * length(c1) * length(c2))
Существует два ускорения - как данные генерируются, а затем сам расчет.
Генерация данных
Самый быстрый способ - это просто.Вместо транспонирования и распечатки вы можете использовать as.matrix
для ясности и небольшого увеличения скорости.Или вы можете сохранить expand.grid
как data.frame, который будет похож на решение tidyr
, создающее тиббл.
Эквивалент data.table равен CJ(c0, c1, c2)
и примерно в 10 раз быстрее, чемсамый быстрый базовый или тидирный эквивалент.
#Creating dataset
Unit: milliseconds
expr min lq mean median uq max neval
original 1185.10 1239.37 1478.46 1503.68 1690.47 1899.37 10
as.matrix 1023.49 1041.72 1213.17 1198.24 1360.51 1420.78 10
expand.grid 764.43 840.11 1030.13 1030.79 1146.82 1354.06 10
tidyr_complete 2811.00 2948.86 3118.33 3158.59 3290.21 3364.52 10
tidyr_crossing 1154.94 1171.01 1311.71 1233.40 1545.30 1609.86 10
data.table_CJ 154.71 155.30 175.65 162.54 174.96 291.14 10
Другой подход заключается в использовании неравных объединений или предварительной фильтрации данных.Мы знаем, что если c0 > c1
или c1 > c2
, то результат суммирования будет равен 0. Таким образом, мы можем отфильтровать комбинации, которые, как мы знаем, нам не нужно хранить в памяти, что создает комбинации быстрее.
Хотя оба эти подхода медленнее, чем data.table::CJ()
, они лучше подходят для тройного суммирования.
# 'data.table_CJ_filter' = CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
#'tidyr_cross_filter' = crossing(c0, c1) %>% filter(c0 <= c1) %>% crossing(c2) %>% filter(c1 <= c2)
#Creating dataset with future calcs in mind
Unit: milliseconds
expr min lq mean median uq max neval
data.table_non_equi 358.41 360.35 373.95 374.57 383.62 400.42 10
data.table_CJ_filter 515.50 517.99 605.06 527.63 661.54 856.43 10
tidyr_cross_filter 776.91 783.35 980.19 928.25 1178.47 1287.91 10
Вычисление суммы
@ Решение Джона Спринга великолепно.case_when
и ifelse
векторизованы, тогда как ваши оригинальные операторы if ... else
не были.Я перевел ответ Джона на Базу R. Это быстрее, чем ваше первоначальное решение, но все равно занимает примерно на 50% больше времени, чем dplyr
.
Одно замечание: если вы сделали неэквивалентное объединение, вы можете еще больше упроститьcase_when
потому что мы уже выполнили фильтрацию - все оставшиеся строки получают 1, 1/2 или 1/6.Обратите внимание, что предварительно отфильтрованные решения примерно в 10–30 раз быстрее, чем данные, которые не были предварительно отфильтрованы.
Unit: milliseconds
expr min lq mean median uq max neval
base 5666.93 6003.87 6303.27 6214.58 6416.42 7423.30 10
dplyr 3633.48 3963.47 4160.68 4178.15 4395.96 4530.15 10
data.table 236.83 262.10 305.19 268.47 269.44 495.22 10
dplyr_pre_filter 378.79 387.38 459.67 418.58 448.13 765.74 10
Сбор их вместе
Окончательное решение, представленное в началезанимает меньше секунды.Версия dplyr
, которая составляет менее 2 секунд.Оба решения основаны на предварительной фильтрации перед переходом к логическому выражению if ... else
.
Unit: milliseconds
expr min lq mean median uq max neval
dt_res 589.83 608.26 736.34 642.46 760.18 1091.1 10
dt_CJ_res 750.07 764.78 905.12 893.73 1040.21 1140.5 10
dplyr_res 1156.69 1169.84 1363.82 1337.42 1496.60 1709.8 10
Данные / код
# /9741132/samyi-bystryi-sposob-vychislit-eto-troinoe-summirovanie-v-r
library(dplyr)
library(tidyr)
library(data.table)
options(digits = 5)
set.seed(123)
nclasses = 3
N_obs = 300
c0 <- rnorm(N_obs)
c1 <- rnorm(N_obs)
c2 <- rnorm(N_obs)
# Base R Data Generation --------------------------------------------------
mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
df <- expand.grid(c0,c1,c2)
identical(mat, unname(as.matrix(df))) #TRUE - names are different with as.matrix
# tidyr and data.table Data Generation ------------------------------------
tib <- crossing(c0, c1, c2) #faster than complete
tib2 <- crossing(c0, c1)%>% #faster but similar in concept to non-equi
filter(c0 <= c1)%>%
crossing(c2)%>%
filter(c1 <= c2)
dt <- data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
][c0 <= c1 & c1 <= c2, ]
# Base R summation --------------------------------------------------------
sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
))
) / (length(c0)*length(c1)*length(c2))
# dplyr summation ---------------------------------------------------------
tib %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
c0 == c1 & c1 == c2 ~ 1/6,
TRUE ~ 0)) %>%
summarize(mean_res = mean(res))
# data.table summation ----------------------------------------------------
#why base doesn't have case_when, who knows
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2,
ifelse(c0 == c1 & c1 == c2, 1/6)
)))
] / (length(c0) * length(c1) * length(c2))
CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6
)))
] / (length(c0) * length(c1) * length(c2))
# Benchmarking ------------------------------------------------------------
library(microbenchmark)
# Data generation
microbenchmark('original' = {
matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
}
, 'as.matrix' = {
as.matrix(expand.grid(c0,c1,c2))
}
, 'expand.grid' = {
expand.grid(c0,c1,c2) #keep it simpler
}
, 'tidyr_complete' = {
tibble(c0, c1, c2) %>% complete(c0, c1, c2)
}
, 'tidyr_crossing' = {
crossing(c0, c1, c2)
}
, 'data.table_CJ' = {
CJ(c0,c1,c2)
}
, times = 10)
microbenchmark('data.table_non_equi' = {
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
][c0 <= c1 & c1 <= c2, ]
}
, 'data.table_CJ_filter' = {
CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
}
, 'tidyr_cross_filter' = {
crossing(c0,c1)%>%filter(c0 <= c1)%>% crossing(c2)%>% filter(c1 <= c2)
}
, times = 10
)
# Summation Calculation
microbenchmark('base' = {
sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
))
) / (length(c0)*length(c1)*length(c2))
}
, 'dplyr' = {
tib %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
c0 == c1 & c1 == c2 ~ 1/6,
TRUE ~ 0)) %>%
summarize(mean_res = mean(res))
}
, 'data.table' = {
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_pre_filter' = {
tib2 %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
TRUE ~ 1/6)) %>%
summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10)
# Start to Finish
microbenchmark('dt_res' = {
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dt_CJ_res' = {
CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_res' = {
crossing(c0, c1)%>% #faster but similar in concept to non-equi
filter(c0 <= c1)%>%
crossing(c2)%>%
filter(c1 <= c2)%>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
TRUE ~ 1/6)) %>%
summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10
)