Как эффективно распараллелить brms :: brm? - PullRequest
0 голосов
/ 04 января 2019

Краткое описание проблемы

Я подгоняю модель brms::brm_multiple() к большому набору данных, где отсутствующие данные вменяются с использованием пакета mice. Размер набора данных делает использование параллельной обработки очень желательным. Однако мне не ясно, как наилучшим образом использовать вычислительные ресурсы, потому что мне неясно, как brms делит выборку на вменяемом наборе данных между ядрами.

Как выбрать следующее, чтобы максимально эффективно использовать вычислительные ресурсы?

  • количество импутаций (m)
  • количество цепей (chains)
  • количество ядер (cores)

Концептуальный пример

Допустим, я наивно (или намеренно глупо ради примера) выбираю m = 5, chains = 10, cores = 24. Таким образом, есть 5 х 10 = 50 цепочек, которые должны быть распределены среди 24 ядер, зарезервированных на HPC. Без параллельной обработки это займет ~ 50 единиц времени (без учета времени компиляции).

Я могу представить три стратегии распараллеливания для brms_multiple(), но могут быть и другие:

Сценарий 1: Вмененные наборы данных параллельно, связанные цепочки в последовательный

Здесь каждая из 5 импутаций выделяется своему собственному процессору, который последовательно проходит 10 цепочек. Время обработки составляет 10 единиц (увеличение скорости в 5 раз по сравнению с непараллельной обработкой), но при плохом планировании было потрачено 19 ядер x 10 единиц времени = 190 единиц времени ядра (ctu; = 80% зарезервированных вычислительных ресурсов). Эффективным решением было бы установить cores = m.

Сценарий 2: Вмененные наборы данных в последовательных, связанных цепочках параллельно

Здесь выборка начинается с взятия первого вмененного набора данных и запуска одной из цепочек для этого набора данных на каждом из 10 различных ядер. Затем это повторяется для остальных четырех вмененных наборов данных. Обработка занимает 5 единиц времени (увеличение скорости в 10 раз по сравнению с последовательной обработкой и улучшение в 2 раза по сравнению со сценарием 1) Однако и здесь вычислительные ресурсы тратятся впустую: 14 ядер x 5 единиц времени = 70 куб. Эффективным решением было бы установить cores = chains

Сценарий 3: Свободно для всех, когда каждое ядро ​​принимает ожидающую комбинацию вменения / цепочки, когда она становится доступной, пока все не будут обработаны.

Здесь выборка начинается с распределения всех 24 ядер, каждое из которых находится в одной из 50 ожидающих цепочек. После того, как они завершили свои итерации, обрабатывается вторая партия из 24 цепочек, в результате чего общее число обработанных цепочек достигает 48. Но теперь осталось только две цепочки, и 22 ядра простаивают в течение 1 единицы времени. Общее время обработки составляет 3 единицы времени, а потраченный впустую вычислительный ресурс - 22 ctu. Эффективным решением было бы установить cores, кратное m x chains.

Минимальный воспроизводимый пример

Этот код сравнивает время вычислений, используя пример, модифицированный из brms виньетка . Здесь мы установим m = 10, chains = 6 и cores = 4. Таким образом, всего будет обработано 60 цепочек. В этих условиях, я ожидаю, что улучшение скорости (по сравнению с последовательной обработкой) выглядит следующим образом *:

  • Сценарий 1: 60 / (6 цепей x потолок (10 м / 4 жилы)) = 3,3x
  • Сценарий 2: 60 / (потолок (6 цепей / 4 сердечника) x 10 м) = 3,0x
  • Сценарий 3: 60 / потолок ((6 цепей х 10 м) / 4 ядра) = 4,0x

* (потолок / округление используется, потому что цепи не могут быть разделены между ядрами)

library(brms)
library(mice)
library(tictoc)  # convenience functions for timing

# Load data
data("nhanes", package = "mice")

# There are 10 imputations x 6 chains = 60 total chains to be processed
imp <- mice(nhanes, m = 10, print = FALSE, seed = 234023)

# Fit the model first to get compilation out of the way
fit_base <- brm_multiple(bmi ~ age*chl, data = imp, chains = 6,
                         iter = 10000, warmup = 2000)

# Use update() function to avoid re-compiling time
# Serial processing (127 sec on my machine)
tic()  # start timing
fit_serial <- update(fit_base, .~., cores = 1L)
t_serial <- toc()  # stop timing
t_serial <- diff(unlist(t_serial)[1:2])  # calculate seconds elapsed

# Parallel processing with 3 cores (82 sec)
tic()
fit_parallel <- update(fit_base, .~., cores = 4L)
t_parallel <- toc()
t_parallel <- diff(unlist(t_parallel)[1:2])  # calculate seconds elapsed

# Calculate speed up ratio
t_serial/t_parallel  # 1.5x

Я явно что-то упускаю. Я не могу различить сценарии с таким подходом.

...