R: Как я мог бы лучше векторизовать этот кусок кода? - PullRequest
0 голосов
/ 20 марта 2020

У меня есть фрагмент кода, который выдает p_values ​​с помощью повторной выборки Монте-Карло заданной статистики c, созданной для проверки формы независимости между переменными. Я определил с помощью профилирования 3 инструкции, которые занимают 95% времени. Сначала я дам вам пример данных и исполняемый код, а затем опишу, что делают эти строки: если у вас есть лучшие идеи, чтобы сделать это быстрее, пожалуйста, скажите мне !!

Вот MWE :

exemple_data = list(z = structure(c(0.627450980392157, 0.372549019607843, 0.637254901960784,
0.843137254901961, 0.450980392156863, 0.686274509803922, 0.0980392156862745,
0.901960784313726, 0.686274509803922, 0.666666666666667, 0.823529411764706,
0.137254901960784, 0.92156862745098, 0.784313725490196, 0.647058823529412,
0.235294117647059, 0.647058823529412, 0.431372549019608, 0.0980392156862745,
0.0196078431372549, 0.803921568627451, 0.666666666666667, 0.0980392156862745,
0.509803921568627, 0.764705882352941, 0.372549019607843, 0.431372549019608,
0.598039215686274, 0.941176470588235, 0.333333333333333, 0.0196078431372549,
0.549019607843137, 0.392156862745098, 0.490196078431373, 0.392156862745098,
0.666666666666667, 0.764705882352941, 0.0784313725490196, 0.274509803921569,
0.862745098039216, 0.196078431372549, 0.92156862745098, 0.215686274509804,
0.235294117647059, 0.529411764705882, 0.549019607843137, 0.980392156862745,
0.274509803921569, 0.392156862745098, 0.431372549019608, 0.941176470588235,
0.156862745098039, 0.862745098039216, 0.901960784313726, 0.686274509803922,
0.117647058823529, 0.882352941176471, 0.294117647058824, 0.254901960784314,
0.303921568627451, 0.588235294117647, 0.313725490196078, 0.529411764705882,
0.705882352941177, 0.705882352941177, 0.745098039215686, 0.215686274509804,
0.980392156862745, 0.803921568627451, 0.745098039215686, 0.725490196078431,
0.0784313725490196, 0.745098039215686, 0.882352941176471, 0.568627450980392,
0.519607843137255, 0.235294117647059, 0.686274509803922, 0.588235294117647,
0.843137254901961, 0.0980392156862745, 0.823529411764706, 0.117647058823529,
0.294117647058824, 0.137254901960784, 0.313725490196078, 0.96078431372549,
0.0392156862745098, 0.156862745098039, 0.549019607843137, 0.0392156862745098,
0.529411764705882, 0.666666666666667, 0.764705882352941, 0.0980392156862745,
0.411764705882353, 0.607843137254902, 0.176470588235294, 0.0196078431372549,
0.156862745098039, 0.607843137254902, 0.411764705882353, 0.882352941176471,
0.607843137254902, 0.490196078431373, 0.882352941176471, 0.176470588235294,
0.784313725490196, 0.647058823529412, 0.588235294117647, 0.980392156862745,
0.274509803921569, 0.470588235294118, 0.627450980392157, 0.941176470588235,
0.137254901960784, 0.627450980392157, 0.156862745098039, 0.117647058823529,
0.823529411764706, 0.490196078431373, 0.0392156862745098, 0.833333333333333,
0.862745098039216, 0.176470588235294, 0.92156862745098, 0.490196078431373,
0.549019607843137, 0.450980392156863, 0.92156862745098, 0.470588235294118,
0.254901960784314, 0.803921568627451, 0.823529411764706, 0.627450980392157,
0.901960784313726, 0.196078431372549, 0.725490196078431, 0.725490196078431,
0.901960784313726, 0.519607843137255, 0.509803921568627, 0.705882352941177,
0.666666666666667, 0.196078431372549, 0.274509803921569, 0.784313725490196,
0.343137254901961, 0.313725490196078, 0.372549019607843, 0.156862745098039,
0.705882352941177, 0.313725490196078, 0.411764705882353, 0.607843137254902,
0.0588235294117647, 0.588235294117647, 0.196078431372549, 0.137254901960784,
0.0784313725490196, 0.764705882352941, 0.745098039215686, 0.372549019607843,
0.372549019607843, 0.0588235294117647, 0.784313725490196, 0.862745098039216,
0.254901960784314, 0.0784313725490196, 0.245098039215686, 0.705882352941177,
0.352941176470588, 0.598039215686274, 0.431372549019608, 0.882352941176471,
0.568627450980392, 0.470588235294118, 0.509803921568627, 0.470588235294118,
0.303921568627451, 0.843137254901961, 0.450980392156863, 0.411764705882353,
0.215686274509804, 0.245098039215686, 0.647058823529412, 0.294117647058824,
0.637254901960784, 0.549019607843137, 0.725490196078431, 0.254901960784314,
0.0196078431372549, 0.96078431372549, 0.96078431372549, 0.509803921568627,
0.862745098039216, 0.117647058823529, 0.833333333333333, 0.92156862745098,
0.411764705882353, 0.215686274509804, 0.686274509803922, 0.235294117647059,
0.352941176470588, 0.470588235294118, 0.0784313725490196, 0.843137254901961,
0.343137254901961, 0.196078431372549, 0.117647058823529, 0.352941176470588,
0.0588235294117647, 0.941176470588235, 0.745098039215686, 0.274509803921569,
0.294117647058824, 0.392156862745098, 0.764705882352941, 0.980392156862745,
0.352941176470588, 0.431372549019608, 0.901960784313726, 0.137254901960784,
0.568627450980392, 0.0392156862745098, 0.96078431372549, 0.803921568627451,
0.0196078431372549, 0.0588235294117647, 0.803921568627451, 0.333333333333333,
0.568627450980392, 0.450980392156863, 0.333333333333333, 0.96078431372549,
0.450980392156863, 0.333333333333333, 0.568627450980392, 0.529411764705882,
0.215686274509804, 0.392156862745098, 0.725490196078431, 0.490196078431373,
0.0392156862745098, 0.980392156862745, 0.176470588235294, 0.941176470588235,
0.0588235294117647, 0.176470588235294, 0.784313725490196), .Dim = c(5L,
50L), .Dimnames = list(c("sr", "pop15", "pop75", "dpi", "ddpi"
), c("Australia", "Austria", "Belgium", "Bolivia", "Brazil",
"Canada", "Chile", "China", "Colombia", "Costa Rica", "Denmark",
"Ecuador", "Finland", "France", "Germany", "Greece", "Guatamala",
"Honduras", "Iceland", "India", "Ireland", "Italy", "Japan",
"Korea", "Luxembourg", "Malta", "Norway", "Netherlands", "New Zealand",
"Nicaragua", "Panama", "Paraguay", "Peru", "Philippines", "Portugal",
"South Africa", "South Rhodesia", "Spain", "Sweden", "Switzerland",
"Turkey", "Tunisia", "United Kingdom", "United States", "Venezuela",
"Zambia", "Jamaica", "Uruguay", "Libya", "Malaysia"))), bp = c(0.359829289416585,
0.565182697194852, 0.451504788384914, 0.411767914767176, 0.607858899990815
), binary_repr = structure(c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0), .Dim = c(32L, 5L)), n = 50L, d = 5L)

library(profvis)

profvis({
  z = exemple_data$z
   bp = exemple_data$bp
   binary_repr = exemple_data$binary_repr
   n = exemple_data$n
   d = exemple_data$d
   N=999 # <<<< The number of bootstrap resamples. I need this to be like 10^3 or 10^4

   ############################### Start of the function
   # prerequisites :
   d = nrow(z)
   n = ncol(z)

   min = bp*t(binary_repr)
   max = bp^t(1-binary_repr)

   lambda_l = apply(max-min,2,prod) # 2^d
   lambda_k = vapply(1:d,function(d_rem){lambda_l/((max-min)[d_rem,])},lambda_l) # 2^d * d

   # first, we comput the empirical value of the statistic :
   core = vapply(1:2^d,function(.x){((min[,.x]<=z)*(max[,.x]>z))==1},FUN.VALUE = z) # dims : d, n, 2^d

   f_l = colMeans(apply(core,2:3,prod)) # 2^d
   f_k = vapply(1:d,function(d_rem){ colMeans(apply(core[-d_rem,,,drop=FALSE],2:3,prod))},f_l) # 2^d, d

   statistic <- sum(f_l^2/lambda_l) -2 * colSums(f_k * f_l / lambda_k) # d

   # then we bootstrap it :
   z_rep = vapply(1:N,function(i){z},z) # d, n, N

   z_repeats = vapply(1:d,function(i){
     z = z_rep
     z[i,,] = runif(n*N)
     z
   },z_rep) # d, n, N, D=d


   ####################################################### This part takes a lot of time

   cores = vapply(1:d,function(d_rem){
     vapply(1:2^d,function(.x){
       ((min[,.x]<=z_repeats[,,,d_rem])*(max[,.x]>z_repeats[,,,d_rem]))==1
     }, FUN.VALUE = z_repeats[,,,d_rem])
   }, FUN.VALUE = array(0,c(d,n,N,2^d))) # d, n, N, 2^d, d

   f_l = colMeans(apply(cores,2:5,prod))
   f_k = vapply(1:d,function(d_rem){
     plop = cores[-d_rem,,,,d_rem,drop=FALSE]
     dim(plop) <- dim(plop)[1:4]
     return(colMeans(apply(plop, 2:4, prod)))
   },array(0.,c(N,2^d))) #(N,2^d,d)

   ####################################################### How could this be done more efficiently ?

   samples = apply(aperm(f_l^2,c(2,3,1))/lambda_l - 2 * aperm(f_k*f_l,c(2,3,1))/vapply(1:N,function(i){lambda_k},lambda_k),c(2,3),sum)
   p_val = rowMeans(statistic < samples)
})

В выделенной мной части у меня есть 3 инструкции. Первый из них:

  cores = vapply(1:d,function(d_rem){
    vapply(1:2^d,function(.x){
      ((min[,.x]<=z_repeats[,,,d_rem])*(max[,.x]>z_repeats[,,,d_rem]))==1
    }, FUN.VALUE = z_repeats[,,,d_rem])
  }, FUN.VALUE = array(0,c(d,n,N,2^d))) # d, n, N, 2^d, d

Создает 5-мерный массив, который содержит простые логические значения, соответствующие min < z_repeats < max (с правильными размерами в нужном месте). Я не нашел лучшего способа вычислить это.

Затем два следующих вызова уменьшают эту переменную cores таким же образом, но с другим поднабором:

f_l = colMeans(apply(cores,2:5,prod))

не требует пояснений, а

f_k = vapply(1:d,function(d_rem){
    plop = cores[-d_rem,,,,d_rem,drop=FALSE]
    dim(plop) <- dim(plop)[1:4]
    return(colMeans(apply(plop, 2:4, prod)))
    },array(0.,c(N,2^d))) #(N,2^d,d)

делает то же самое, за исключением поднабора. Особая осторожность при настройке и переустановке размера заключается в том, что d может равняться 2 в некоторых приложениях, вызывая cores[-d_rem,,,,d_rem,drop=FALSE] или cores[-d_rem,,,,d_rem] не правильной размерности (соответственно 5 и 3, хотя мне нужно 4) .

Я начал с функции all вместо prod в этих двух вызовах, но я обнаружил, что prod намного быстрее.

Очевидно, что в этих 3 строках еще есть место для оптимизации. Но я так и не узнал, что я мог сделать больше.

Пожалуйста, скажите мне, если у вас есть идеи; Спасибо;)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...