Реализация Walk Walk Metropolis Hastings в R с использованием логарифмической шкалы - PullRequest
3 голосов
/ 13 января 2020

Контекст

Я выглядел буквально повсюду, но не смог найти полную реализацию алгоритма Метрополиса-Гастингса Random Walk с использованием логарифмической шкалы. Под логарифмической шкалой я подразумеваю, что мы работаем с логарифмом целевого распределения (который обычно будет апостериорным). Для простоты здесь у меня есть минимальный рабочий пример в R, но я был бы рад также реализации python.

Проблемы

Я не понимаю следующие вещи:

  1. Если я использую шкалу журнала, должен ли я также использовать предложение журнала?
  2. Если я использую шкалу журнала, имеет ли значение, если функция, которая оценивает журнал моего целевого распределения, просто пропорционально цели журнала?
  3. Как мне ее реализовать?

Минимальный рабочий пример контекста

Байесовская логистика c Регрессия (с 0-1 метками) log-posterior задается как

logistic regression

Я хочу сделать выборку из этого апостериорного значения, используя Random- Алгоритм Метрополиса-Гастингса. Согласно сообщениям, таким как это и это и это сообщение в блоге лучше использовать задний лог .

Минимальный рабочий пример - функция Метрополиса-Гастингса в случайном порядке в R

library(MASS)
rwmh_log <- function(start, niter, logtarget){
    # Set current z to the initial point and calculate its log target to save computations
    z  <- start    # It's a column vector
    pz <- logtarget(start)
    # Generate matrix containing the samples. Initialize first sample with the starting value
    samples <- matrix(0, nrow=niter, ncol=nrow(start))
    # Generate uniform random numbers in advance, to save computation. Take logarithm?
    log_u <- log(runif(niter))
    # Proposal is a multivariate standard normal distribution. Generate samples and
    # later on use linearity property of Gaussian distribution
    normal_shift <- mvrnorm(n=niter, mu=c(0,0,0), Sigma=diag(nrow(start)))
    for (i in 2:niter){
        # Sample a candidate
        candidate <- z + normal_shift[i, ]
        # calculate log target of candidate and store it in case it gets accepted
        p_candidate <- logtarget(candidate)
        # use decision rule explained in blog posts
        if (log_u[i] <= pz - p_candidate){
            # Accept!
            z  <- candidate
            pz <- p_candidate
        }
        # Finally add the sample to our matrix of samples
        samples[i, ] <- z
    }
    return(samples)
}

Logisti c Генерация данных регрессии для запуска минимального рабочего примера

set.seed(123)
# Number of observations, mean and variance-covariance matrix for class 0
n1 <- 100
m1 <- c(6, 6)
s1 <- matrix(c(1, 0, 0, 10), nrow=2, ncol=2)
# Number of observations, mean and variance-covariance matrix for class 1
n2 <- 100
m2 <- c(-1, 1)
s2 <- matrix(c(1, 0, 0, 10), nrow=2, ncol=2)
# Generate explanatory data by sampling bivariate normal distributions
class1 <- mvrnorm(n1, m1, s1)
class2 <- mvrnorm(n2, m2, s2)
# Generate class labels. First n1 are of class 0, last n2 are of class 1
y  <- c(rep(0, n1), rep(1, n2))
X  <- rbind(class1, class2)

График сгенерированных данных набор (просто для удовольствия)

library(ggplot2)
data <- data.frame(X, y)
ggplot(data=data, aes(x=X1, y=X2, color=as_factor(y))) + 
    geom_point() + 
    theme(plot.title=element_text(hjust=0.5, size=20)) + 
    labs(color="Class", title="Linearly Separable Dataset")

dataset

Logisti c задний журнал регрессии для запуска минимального рабочего примера

# NOTICE THAT X NEEDS A COLUMN OF 1s FOR THE BIAS
X <- cbind(1, X)
log_posterior_unnormalized <- function(beta){
    log_prior      <- -0.5*sum(beta^2)
    log_likelihood <- -sum(log(1 + exp((1 - 2*y) * (X %*% beta))))
    return(log_prior + log_likelihood)
}

Выполнение минимального рабочего примера

start    <- matrix(0, nrow=3, ncol=1) # 3 because of the bias
niter    <- 2000
samples  <- rwmh_log(start=start, niter=niter, logtarget=log_posterior_unnormalized)

Результаты минимального рабочего примера

Это ошибки, говорящие

Error in if (log_u[i] <= pz - p_candidate) { : 
  missing value where TRUE/FALSE needed

, что в основном означает, что оба pz и p_candidate становится -Inf. Что происходит ??

Запустив этот алгоритм в командной строке, мы видим, что он что-то делает для первых 102 итераций и разрывов на 103-м. Образцы, сгенерированные до этого момента, вставлены ниже. Вы можете видеть, как они взрываются.

> samples[1:102, ]
              [,1]       [,2]       [,3]
  [1,]  0.00000000  0.0000000  0.0000000
  [2,]  1.41003898  0.7678739 -1.1686514
  [3,]  1.41003898  0.7678739 -1.1686514
  [4,]  1.41003898  0.7678739 -1.1686514
  [5,]  1.12738511  1.3976872 -0.4979555
  [6,]  1.12738511  1.3976872 -0.4979555
  [7,]  1.12738511  1.3976872 -0.4979555
  [8,]  1.68208878  2.7622693  0.2584510
  [9,]  1.68208878  2.7622693  0.2584510
 [10,]  1.68208878  2.7622693  0.2584510
 [11,]  1.68208878  2.7622693  0.2584510
 [12,]  1.48830009  3.6458290  0.5262860
 [13,]  3.13261741  4.2619841  1.1795437
 [14,]  3.13261741  4.2619841  1.1795437
 [15,]  2.02763486  5.2880473  0.7658672
 [16,]  2.02763486  5.2880473  0.7658672
 [17,]  1.96977976  6.9251877  0.6729261
 [18,] -0.68979341  7.6298268  1.1032108
 [19,]  0.24805648  7.7550307  1.6386097
 [20,]  0.40805008  8.9963608  1.0833313
 [21,]  0.10764974  9.1156086  2.8628342
 [22,] -0.38268725 10.3344574  3.1492587
 [23,] -0.38268725 10.3344574  3.1492587
 [24,]  0.85171606 10.1059391  4.4215254
 [25,]  0.30011080 11.1485574  3.7030592
 [26,]  0.61490763 12.6665140  3.2527206
 [27,] -0.89593039 12.3662036  5.6501731
 [28,] -0.53032703 13.1415187  5.6613023
 [29,] -2.97562829 13.3948230  7.2948707
 [30,] -2.97562829 13.3948230  7.2948707
 [31,] -2.67296962 14.8767398  7.1043539
 [32,] -2.34678595 15.0634645  7.4827778
 [33,] -2.28767265 15.1808296  7.7828163
 [34,] -2.28767265 15.1808296  7.7828163
 [35,] -2.22941911 15.9571394  7.8020756
 [36,] -2.22941911 15.9571394  7.8020756
 [37,] -2.22941911 15.9571394  7.8020756
 [38,] -3.61326092 17.1933051  8.8868507
 [39,] -3.61326092 17.1933051  8.8868507
 [40,] -3.61326092 17.1933051  8.8868507
 [41,] -3.61326092 17.1933051  8.8868507
 [42,] -1.81956128 17.3817482  9.3416200
 [43,] -3.16750568 17.4750150 10.0015226
 [44,] -0.79642112 17.6399588  9.8016328
 [45,] -1.32371016 18.3821248  9.1565188
 [46,] -1.32371016 18.3821248  9.1565188
 [47,] -2.34696374 19.0425533  9.5953375
 [48,] -2.51424008 18.6796395 10.4786403
 [49,] -2.51424008 18.6796395 10.4786403
 [50,] -2.51424008 18.6796395 10.4786403
 [51,] -3.36106702 20.5262499 11.9090427
 [52,] -2.63987051 20.3130838 12.9556715
 [53,] -2.63987051 20.3130838 12.9556715
 [54,] -0.43401179 19.8467996 13.6708499
 [55,] -1.07079158 19.0419739 14.5880249
 [56,] -1.07079158 19.0419739 14.5880249
 [57,] -1.07079158 19.0419739 14.5880249
 [58,] -1.65936235 21.3411496 14.1030373
 [59,] -0.98841697 21.9482180 14.3336541
 [60,] -0.09823962 23.8173624 14.0384963
 [61,]  0.02893740 24.2862548 14.9104612
 [62,]  0.02893740 24.2862548 14.9104612
 [63,] -0.72302884 24.7495870 15.4289650
 [64,]  0.13822989 26.9019761 15.0382800
 [65,]  0.13822989 26.9019761 15.0382800
 [66,]  0.04357883 26.8745540 16.2482905
 [67,] -0.88960758 27.9014297 16.9891905
 [68,] -0.68093651 29.0824771 18.7134528
 [69,] -0.43058184 29.8667399 18.7786067
 [70,] -1.29463115 31.0491849 19.9036095
 [71,] -1.24440787 29.5035870 21.8790285
 [72,] -1.24440787 29.5035870 21.8790285
 [73,] -1.55557884 31.5517166 20.5560774
 [74,] -1.99585038 32.3840898 20.3167258
 [75,] -1.99585038 32.3840898 20.3167258
 [76,] -1.99585038 32.3840898 20.3167258
 [77,] -1.89981602 33.6591331 22.0290308
 [78,] -1.89981602 33.6591331 22.0290308
 [79,] -2.44680063 34.6652981 22.4020355
 [80,] -2.44680063 34.6652981 22.4020355
 [81,] -1.60876820 35.8310586 22.4224862
 [82,] -1.45422863 36.5512324 22.7365438
 [83,] -1.18233421 35.5138059 24.0647585
 [84,] -0.24570214 37.0486802 24.1860769
 [85,]  0.03168349 37.4168439 24.8989192
 [86,]  0.03168349 37.4168439 24.8989192
 [87,]  0.03168349 37.4168439 24.8989192
 [88,]  0.03168349 37.4168439 24.8989192
 [89,]  0.24215122 37.5586010 26.5258005
 [90,]  0.24215122 37.5586010 26.5258005
 [91,]  0.24215122 37.5586010 26.5258005
 [92,]  0.58862001 39.3607620 27.9298507
 [93,] -0.27926896 40.1595364 29.2239346
 [94,] -0.27926896 40.1595364 29.2239346
 [95,] -0.27926896 40.1595364 29.2239346
 [96,] -0.27926896 40.1595364 29.2239346
 [97,] -0.17760335 40.9974928 29.4057818
 [98,]  0.43759118 42.1254915 29.5706227
 [99,] -0.13618441 42.4222319 29.9347374
[100,]  0.93734292 42.7479366 30.4868951
[101,]  0.93734292 42.7479366 30.4868951
[102,]  0.93734292 42.7479366 30.4868951

Глядя на значение log_posterior_unnormalized на таких образцах, мы можем видеть, что оно действительно взрывается до бесконечности.

> apply(samples[1:107, ], 1, log_posterior_unnormalized)
  [1]   -138.6294   -310.8512   -310.8512   -310.8512   -814.2805   -814.2805   -814.2805  -2196.6386  -2196.6386
 [10]  -2196.6386  -2196.6386  -2969.5265  -3877.5393  -3877.5393  -4299.3816  -4299.3816  -5418.8160  -6098.4219
 [19]  -6520.9842  -7102.3063  -8222.7219  -9253.5215  -9253.5215  -9933.2233 -10196.3848 -11024.7336 -12248.1868
 [28] -12816.9682 -13964.8656 -13964.8656 -14897.0547 -15285.9802 -15566.7016 -15566.7016 -16129.8982 -16129.8982
 [37] -16129.8982 -17662.7135 -17662.7135 -17662.7135 -17662.7135 -18160.6234 -18610.1692 -18687.7953 -18765.9209
 [46] -18765.9209 -19482.5643 -19800.5161 -19800.5161 -19800.5161 -22021.5857 -22595.2153 -22595.2153 -22838.3542
 [55] -22862.6686 -22862.6686 -22862.6686 -24129.0525 -24740.6719 -25905.4325 -26827.0686 -26827.0686 -27470.8678
 [64] -28779.8978 -28779.8978 -29565.7594 -30755.0321 -32768.3554 -33383.2801 -34951.9200 -35192.1701 -35192.1701
 [73] -35742.4804 -36158.4290 -36158.4290 -36158.4290 -38238.1653 -38238.1653 -39191.6603 -39191.6603 -40079.6649
 [82] -40818.7965 -40987.1734 -42216.0336 -42980.1015 -42980.1015 -42980.1015 -42980.1015 -44203.8974 -44203.8974
 [91] -44203.8974 -46483.0324 -47911.4547 -47911.4547 -47911.4547 -47911.4547 -48647.6199 -49606.0710 -50046.5853
[100]        -Inf        -Inf        -Inf

Это заставляет меня думать, что, возможно, правило принятия решения не работает должным образом? Может быть, он принимает сэмплы, которые он не должен принимать?

Странный результат: Должен ли я сделать выборку ОТРИЦАТЕЛЬНОЙ записи журнала?

Каким-то образом, если я использую алгоритм Random Walk Metropolis-Hastings для сэмплирования отрицательный логарифм зад Я получаю какие-то разумные результаты? Как это возможно? Вот мой код, чтобы увидеть, что происходит.

start    <- matrix(0, nrow=3, ncol=1) # 3 because of the bias
niter    <- 100000
samples  <- rwmh_log(start=start, niter=niter, logtarget=log_posterior_unnormalized)

Если затем мы используем следующий код построения

samplesdf <- data.frame(samples) %>% mutate(rn=row_number())
trace1 <- ggplot(data=samplesdf, aes(x=rn, y=X1)) + geom_line()
trace2 <- ggplot(data=samplesdf, aes(x=rn, y=X2)) + geom_line()
trace3 <- ggplot(data=samplesdf, aes(x=rn, y=X3)) + geom_line()
grid.arrange(trace1, trace2, trace3, ncol=1)

, мы получим

traceplot

и с помощью следующего кода построения мы получаем гистограммы для каждой координаты параметра

hist1 <- ggplot(data=samplesdf, aes(x=X1, stat(density))) + 
  geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
hist2 <- ggplot(data=samplesdf, aes(x=X2, stat(density))) + 
  geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
hist3 <- ggplot(data=samplesdf, aes(x=X3, stat(density))) + 
  geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
grid.arrange(hist1, hist2, hist3, ncol=1)

, что дает

histograms

Эти результаты можно сделать еще лучше, если начать с оценки карты и использовать обратную приближенную матрицу Гессиана в качестве матрицы дисперсии-ковариации для нормального предложения.

set.seed(123)
# Number of observations, mean and variance-covariance matrix for class 0
n1 <- 100
m1 <- c(6, 6)
s1 <- matrix(c(1, 0, 0, 10), nrow=2, ncol=2)
# Number of observations, mean and variance-covariance matrix for class 1
n2 <- 100
m2 <- c(-1, 1)
s2 <- matrix(c(1, 0, 0, 10), nrow=2, ncol=2)
# Generate explanatory data by sampling bivariate normal distributions
class1 <- mvrnorm(n1, m1, s1)
class2 <- mvrnorm(n2, m2, s2)
# Generate class labels. First n1 are of class 0, last n2 are of class 1
y  <- c(rep(0, n1), rep(1, n2))
X  <- rbind(class1, class2)
# Add colum of 1s for the bias
X <- cbind(1, X)
# log posterior
log_posterior_unnormalized <- function(beta){
  log_prior      <- -0.5*sum(beta^2)
  log_likelihood <- -sum(log(1 + exp((1 - 2*y) * (X %*% beta))))
  return(log_prior + log_likelihood)
}
# define negative log posterior to be minimized by optim function
logtarget <- function(x) -log_posterior_unnormalized(x)
# optimize it to find hessian and starting point
r <- optim(c(0,0,0), logtarget, method="BFGS", hessian=TRUE)
# Start at the MAP estimate
niter <- 100000
start <- r$par
z <- start
pz <- logtarget(z)
samples <- matrix(0, nrow=niter, ncol=3)
# Use the inverse of the approximate hessian matrix for our normal proposal
vcov <- solve(r$hessian) 
normal_shift <- mvrnorm(n=niter, mu=c(0,0,0), Sigma=vcov)
samples[1, ] <- start
log_u <- log(runif(niter))
for (i in 2:niter){
  # Sample a candidate
  candidate <- z + normal_shift[i, ]
  # calculate log target of candidate and store it in case it gets accepted
  p_candidate <- logtarget(candidate)
  # use decision rule explained in blog posts
  if (log_u[i] <= pz - p_candidate){
    # Accept!
    z  <- candidate
    pz <- p_candidate
  }
  # Finally add the sample to our matrix of samples
  samples[i, ] <- z
}


samplesdf <- data.frame(samples) %>% mutate(rn=row_number())
trace1 <- ggplot(data=samplesdf, aes(x=rn, y=X1)) + geom_line()
trace2 <- ggplot(data=samplesdf, aes(x=rn, y=X2)) + geom_line()
trace3 <- ggplot(data=samplesdf, aes(x=rn, y=X3)) + geom_line()
grid.arrange(trace1, trace2, trace3, ncol=1)

hist1 <- ggplot(data=samplesdf, aes(x=X1, stat(density))) + 
  geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
hist2 <- ggplot(data=samplesdf, aes(x=X2, stat(density))) + 
  geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
hist3 <- ggplot(data=samplesdf, aes(x=X3, stat(density))) + 
  geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
grid.arrange(hist1, hist2, hist3, ncol=1)

, давая следующие два участки

traceplot2

histograms2

...