Самый быстрый способ вычислить это тройное суммирование в R - PullRequest
0 голосов
/ 17 мая 2019

Моя цель - вычислить следующее тройное суммирование:

$V = \( \frac{1}{n1n2n3} \) \sum_{i=1}^{n1}\sum_{j=1}^{n2}\sum_{k=1}^{n3} I(Y_{1i},Y_{2j},Y_{3k})$

где I (Y1, Y2, Y3) определяется как:

I(Y1,Y2,Y3) = 1 if Y[1] < Y[2] < Y[3]
 I(Y1,Y2,Y3) = 1/2 if Y[1] = Y[2] < Y[3]
 I(Y1,Y2,Y3) = 1/6 if Y[1] = Y[2] = Y[3]
 I(Y1,Y2,Y3) = 0 Otherwise.

Я выполнил вычисления с помощью R, а код:

Проблема в том, что при таком способе вычисления настолько дороги. Я предполагаю, что это связано с использованием expand.grid() для создания матрицы всех комбинаций и последующего вычисления результата.

У кого-нибудь есть более эффективный способ сделать это?

set.seed(123)

nclasses = 3

ind <- function(Y){
  res = 0


if (Y[1] < Y[2] & Y[2] < Y[3]){res = 1}
  else if (Y[1] == Y[2] & Y[2] < Y[3]){res = 1/2}
  else if (Y[1] == Y[2] & Y[2] == Y[3]){res = 1/6}
  else {res = 0}

  return (res)
}

N_obs = 300
c0 <- rnorm(N_obs)
l0 = length(c0)

c1 <- rnorm(N_obs)
l1 = length(c1)

c2 <- rnorm(N_obs)
l2 = length(c2)

mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
dim(mat)

Result <- (1/(l0*l1*l2))*sum(apply(mat, 1, ind))

Ответы [ 2 ]

1 голос
/ 20 мая 2019

tl; dr - таблица данных, использующая неравные объединения, может решить ее за то же время, в течение которого tidyr завершила генерацию данных.Тем не менее, решение tidyr / dplyr выглядит лучше.

data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
  ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
    ][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                      ifelse(c0 == c1 & c1 < c2, 1/2, 1/6
                                      )))
      ] / (length(c0) * length(c1) * length(c2))

Существует два ускорения - как данные генерируются, а затем сам расчет.

Генерация данных

Самый быстрый способ - это просто.Вместо транспонирования и распечатки вы можете использовать as.matrix для ясности и небольшого увеличения скорости.Или вы можете сохранить expand.grid как data.frame, который будет похож на решение tidyr, создающее тиббл.

Эквивалент data.table равен CJ(c0, c1, c2) и примерно в 10 раз быстрее, чемсамый быстрый базовый или тидирный эквивалент.

#Creating dataset
Unit: milliseconds
                expr     min      lq    mean  median      uq     max neval
            original 1185.10 1239.37 1478.46 1503.68 1690.47 1899.37    10
           as.matrix 1023.49 1041.72 1213.17 1198.24 1360.51 1420.78    10
         expand.grid  764.43  840.11 1030.13 1030.79 1146.82 1354.06    10
      tidyr_complete 2811.00 2948.86 3118.33 3158.59 3290.21 3364.52    10
      tidyr_crossing 1154.94 1171.01 1311.71 1233.40 1545.30 1609.86    10
       data.table_CJ  154.71  155.30  175.65  162.54  174.96  291.14    10

Другой подход заключается в использовании неравных объединений или предварительной фильтрации данных.Мы знаем, что если c0 > c1 или c1 > c2, то результат суммирования будет равен 0. Таким образом, мы можем отфильтровать комбинации, которые, как мы знаем, нам не нужно хранить в памяти, что создает комбинации быстрее.

Хотя оба эти подхода медленнее, чем data.table::CJ(), они лучше подходят для тройного суммирования.

# 'data.table_CJ_filter' = CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
#'tidyr_cross_filter' =  crossing(c0, c1) %>% filter(c0 <= c1) %>% crossing(c2) %>% filter(c1 <= c2)

#Creating dataset with future calcs in mind
Unit: milliseconds
                 expr    min     lq   mean median      uq     max neval
  data.table_non_equi 358.41 360.35 373.95 374.57  383.62  400.42    10
 data.table_CJ_filter 515.50 517.99 605.06 527.63  661.54  856.43    10
   tidyr_cross_filter 776.91 783.35 980.19 928.25 1178.47 1287.91    10

Вычисление суммы

@ Решение Джона Спринга великолепно.case_when и ifelse векторизованы, тогда как ваши оригинальные операторы if ... else не были.Я перевел ответ Джона на Базу R. Это быстрее, чем ваше первоначальное решение, но все равно занимает примерно на 50% больше времени, чем dplyr.

Одно замечание: если вы сделали неэквивалентное объединение, вы можете еще больше упроститьcase_when потому что мы уже выполнили фильтрацию - все оставшиеся строки получают 1, 1/2 или 1/6.Обратите внимание, что предварительно отфильтрованные решения примерно в 10–30 раз быстрее, чем данные, которые не были предварительно отфильтрованы.

Unit: milliseconds
             expr     min      lq    mean  median      uq     max neval
             base 5666.93 6003.87 6303.27 6214.58 6416.42 7423.30    10
            dplyr 3633.48 3963.47 4160.68 4178.15 4395.96 4530.15    10
       data.table  236.83  262.10  305.19  268.47  269.44  495.22    10
 dplyr_pre_filter  378.79  387.38  459.67  418.58  448.13  765.74    10

Сбор их вместе

Окончательное решение, представленное в началезанимает меньше секунды.Версия dplyr, которая составляет менее 2 секунд.Оба решения основаны на предварительной фильтрации перед переходом к логическому выражению if ... else.

Unit: milliseconds
      expr     min      lq    mean  median      uq    max neval
    dt_res  589.83  608.26  736.34  642.46  760.18 1091.1    10
 dt_CJ_res  750.07  764.78  905.12  893.73 1040.21 1140.5    10
 dplyr_res 1156.69 1169.84 1363.82 1337.42 1496.60 1709.8    10

Данные / код

# /9741132/samyi-bystryi-sposob-vychislit-eto-troinoe-summirovanie-v-r
library(dplyr)
library(tidyr)
library(data.table)

options(digits = 5)
set.seed(123)

nclasses = 3
N_obs = 300

c0 <- rnorm(N_obs)
c1 <- rnorm(N_obs)
c2 <- rnorm(N_obs)

# Base R Data Generation --------------------------------------------------

mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
df <- expand.grid(c0,c1,c2)

identical(mat, unname(as.matrix(df))) #TRUE - names are different with as.matrix

# tidyr and data.table Data Generation ------------------------------------

tib <- crossing(c0, c1, c2) #faster than complete

tib2 <- crossing(c0, c1)%>% #faster but similar in concept to non-equi
  filter(c0 <= c1)%>%
  crossing(c2)%>%
  filter(c1 <= c2)

dt <-   data.table(c0
                   )[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
                     ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
                       ][c0 <= c1 & c1 <= c2, ]

# Base R summation --------------------------------------------------------

sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
                      ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
                             ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
                      ))
    ) / (length(c0)*length(c1)*length(c2))


# dplyr summation ---------------------------------------------------------

tib %>%
  mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                         c0 == c1 & c1 < c2  ~ 1/2,
                         c0 == c1 & c1 == c2 ~ 1/6,
                         TRUE               ~ 0)) %>%
  summarize(mean_res = mean(res))

# data.table summation ----------------------------------------------------

#why base doesn't have case_when, who knows
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
                ifelse(c0 == c1 & c1 < c2, 1/2,
                       ifelse(c0 == c1 & c1 == c2, 1/6)
                )))
   ] / (length(c0) * length(c1) * length(c2))


CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                             ifelse(c0 == c1 & c1 < c2, 1/2, 1/6
                                             )))
             ] / (length(c0) * length(c1) * length(c2))

# Benchmarking ------------------------------------------------------------

library(microbenchmark)

# Data generation
microbenchmark('original' = {
  matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
}
, 'as.matrix' = {
  as.matrix(expand.grid(c0,c1,c2)) 
}
, 'expand.grid' = {
  expand.grid(c0,c1,c2) #keep it simpler
}
, 'tidyr_complete' = {
  tibble(c0, c1, c2) %>% complete(c0, c1, c2)
}
, 'tidyr_crossing' = {
  crossing(c0, c1, c2)
}
, 'data.table_CJ' = {
  CJ(c0,c1,c2)
}
, times = 10)

microbenchmark('data.table_non_equi' = {
  data.table(c0
             )[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
               ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
                 ][c0 <= c1 & c1 <= c2, ]
}
, 'data.table_CJ_filter' = {
  CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
}
, 'tidyr_cross_filter' = {
  crossing(c0,c1)%>%filter(c0 <= c1)%>% crossing(c2)%>% filter(c1 <= c2)
}
, times = 10
)

# Summation Calculation
microbenchmark('base' = {
  sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
             ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
                    ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
             ))
  ) / (length(c0)*length(c1)*length(c2))
}
, 'dplyr' = {
  tib %>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           c0 == c1 & c1 == c2 ~ 1/6,
                           TRUE               ~ 0)) %>%
    summarize(mean_res = mean(res))
}
, 'data.table' = {
  dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
                  ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
                  ))
     ] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_pre_filter' = {
  tib2 %>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           TRUE ~ 1/6)) %>%
    summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10)

# Start to Finish

microbenchmark('dt_res' = {
  data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
  ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
    ][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                      ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
    ))
    ] / (length(c0) * length(c1) * length(c2))
}
, 'dt_CJ_res' = {
  CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                                 ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
  ))
  ] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_res' = {
  crossing(c0, c1)%>% #faster but similar in concept to non-equi
    filter(c0 <= c1)%>%
    crossing(c2)%>%
    filter(c1 <= c2)%>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           TRUE ~ 1/6)) %>%
    summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10
)

1 голос
/ 18 мая 2019

Оригинал занял 399 секунд на моем компьютере для выполнения строки Result <-.Эта вариация с использованием dplyr & tidyr заняла 7 секунд, чтобы выполнить суммирование, и я получил точно такой же ответ.Я предполагаю, что ускорение происходит из-за того, что версия dplyr векторизована и может выполнять одинаковые вычисления для всех 27 миллионов строк, тогда как оригинал, я подозреваю, пересчитывает что-то каждый раз.

library(dplyr); library(tidyr)

combos <- tibble(Y1 = rnorm(300),
                 Y2 = rnorm(300),
                 Y3 = rnorm(300)) %>%
  complete(Y1, Y2, Y3)

combos %>%
  mutate(res = case_when(Y1  < Y2 & Y2 < Y3  ~ 1,
                         Y1 == Y2 & Y2 < Y3  ~ 1/2,
                         Y1 == Y2 & Y2 == Y3 ~ 1/6,
                         TRUE               ~ 0)) %>%
  summarize(mean_res = mean(res))

Это кажется также разрешимым алгебраически, но я предполагаю, что смысл этого состоял в том, чтобы решить с помощью моделирования.

Если у нас есть три отдельных набора по 300 чисел длиной 16 цифр, каждый из которых нарисован с использованием rnorm, это бесконечно малый шанс того, что все совпадут друг с другом.Таким образом, мы можем игнорировать 2-й и 3-й случаи, которые не происходят с предложенным set.seed и могут потребовать миллиарды прогонов, чтобы встретиться один раз.

Теперь, как часто Y [1] set.seed(123) восходящий сценарий возникает в 22 379 120 из 27 000 000 случаев (82,9%).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...