Контекст
Я выглядел буквально повсюду, но не смог найти полную реализацию алгоритма Метрополиса-Гастингса Random Walk с использованием логарифмической шкалы. Под логарифмической шкалой я подразумеваю, что мы работаем с логарифмом целевого распределения (который обычно будет апостериорным). Для простоты здесь у меня есть минимальный рабочий пример в R
, но я был бы рад также реализации python
.
Проблемы
Я не понимаю следующие вещи:
- Если я использую шкалу журнала, должен ли я также использовать предложение журнала?
- Если я использую шкалу журнала, имеет ли значение, если функция, которая оценивает журнал моего целевого распределения, просто пропорционально цели журнала?
- Как мне ее реализовать?
Минимальный рабочий пример контекста
Байесовская логистика c Регрессия (с 0-1
метками) log-posterior задается как
Я хочу сделать выборку из этого апостериорного значения, используя Random- Алгоритм Метрополиса-Гастингса. Согласно сообщениям, таким как это и это и это сообщение в блоге лучше использовать задний лог .
Минимальный рабочий пример - функция Метрополиса-Гастингса в случайном порядке в R
library(MASS)
rwmh_log <- function(start, niter, logtarget){
# Set current z to the initial point and calculate its log target to save computations
z <- start # It's a column vector
pz <- logtarget(start)
# Generate matrix containing the samples. Initialize first sample with the starting value
samples <- matrix(0, nrow=niter, ncol=nrow(start))
# Generate uniform random numbers in advance, to save computation. Take logarithm?
log_u <- log(runif(niter))
# Proposal is a multivariate standard normal distribution. Generate samples and
# later on use linearity property of Gaussian distribution
normal_shift <- mvrnorm(n=niter, mu=c(0,0,0), Sigma=diag(nrow(start)))
for (i in 2:niter){
# Sample a candidate
candidate <- z + normal_shift[i, ]
# calculate log target of candidate and store it in case it gets accepted
p_candidate <- logtarget(candidate)
# use decision rule explained in blog posts
if (log_u[i] <= pz - p_candidate){
# Accept!
z <- candidate
pz <- p_candidate
}
# Finally add the sample to our matrix of samples
samples[i, ] <- z
}
return(samples)
}
Logisti c Генерация данных регрессии для запуска минимального рабочего примера
set.seed(123)
# Number of observations, mean and variance-covariance matrix for class 0
n1 <- 100
m1 <- c(6, 6)
s1 <- matrix(c(1, 0, 0, 10), nrow=2, ncol=2)
# Number of observations, mean and variance-covariance matrix for class 1
n2 <- 100
m2 <- c(-1, 1)
s2 <- matrix(c(1, 0, 0, 10), nrow=2, ncol=2)
# Generate explanatory data by sampling bivariate normal distributions
class1 <- mvrnorm(n1, m1, s1)
class2 <- mvrnorm(n2, m2, s2)
# Generate class labels. First n1 are of class 0, last n2 are of class 1
y <- c(rep(0, n1), rep(1, n2))
X <- rbind(class1, class2)
График сгенерированных данных набор (просто для удовольствия)
library(ggplot2)
data <- data.frame(X, y)
ggplot(data=data, aes(x=X1, y=X2, color=as_factor(y))) +
geom_point() +
theme(plot.title=element_text(hjust=0.5, size=20)) +
labs(color="Class", title="Linearly Separable Dataset")
Logisti c задний журнал регрессии для запуска минимального рабочего примера
# NOTICE THAT X NEEDS A COLUMN OF 1s FOR THE BIAS
X <- cbind(1, X)
log_posterior_unnormalized <- function(beta){
log_prior <- -0.5*sum(beta^2)
log_likelihood <- -sum(log(1 + exp((1 - 2*y) * (X %*% beta))))
return(log_prior + log_likelihood)
}
Выполнение минимального рабочего примера
start <- matrix(0, nrow=3, ncol=1) # 3 because of the bias
niter <- 2000
samples <- rwmh_log(start=start, niter=niter, logtarget=log_posterior_unnormalized)
Результаты минимального рабочего примера
Это ошибки, говорящие
Error in if (log_u[i] <= pz - p_candidate) { :
missing value where TRUE/FALSE needed
, что в основном означает, что оба pz
и p_candidate
становится -Inf
. Что происходит ??
Запустив этот алгоритм в командной строке, мы видим, что он что-то делает для первых 102
итераций и разрывов на 103
-м. Образцы, сгенерированные до этого момента, вставлены ниже. Вы можете видеть, как они взрываются.
> samples[1:102, ]
[,1] [,2] [,3]
[1,] 0.00000000 0.0000000 0.0000000
[2,] 1.41003898 0.7678739 -1.1686514
[3,] 1.41003898 0.7678739 -1.1686514
[4,] 1.41003898 0.7678739 -1.1686514
[5,] 1.12738511 1.3976872 -0.4979555
[6,] 1.12738511 1.3976872 -0.4979555
[7,] 1.12738511 1.3976872 -0.4979555
[8,] 1.68208878 2.7622693 0.2584510
[9,] 1.68208878 2.7622693 0.2584510
[10,] 1.68208878 2.7622693 0.2584510
[11,] 1.68208878 2.7622693 0.2584510
[12,] 1.48830009 3.6458290 0.5262860
[13,] 3.13261741 4.2619841 1.1795437
[14,] 3.13261741 4.2619841 1.1795437
[15,] 2.02763486 5.2880473 0.7658672
[16,] 2.02763486 5.2880473 0.7658672
[17,] 1.96977976 6.9251877 0.6729261
[18,] -0.68979341 7.6298268 1.1032108
[19,] 0.24805648 7.7550307 1.6386097
[20,] 0.40805008 8.9963608 1.0833313
[21,] 0.10764974 9.1156086 2.8628342
[22,] -0.38268725 10.3344574 3.1492587
[23,] -0.38268725 10.3344574 3.1492587
[24,] 0.85171606 10.1059391 4.4215254
[25,] 0.30011080 11.1485574 3.7030592
[26,] 0.61490763 12.6665140 3.2527206
[27,] -0.89593039 12.3662036 5.6501731
[28,] -0.53032703 13.1415187 5.6613023
[29,] -2.97562829 13.3948230 7.2948707
[30,] -2.97562829 13.3948230 7.2948707
[31,] -2.67296962 14.8767398 7.1043539
[32,] -2.34678595 15.0634645 7.4827778
[33,] -2.28767265 15.1808296 7.7828163
[34,] -2.28767265 15.1808296 7.7828163
[35,] -2.22941911 15.9571394 7.8020756
[36,] -2.22941911 15.9571394 7.8020756
[37,] -2.22941911 15.9571394 7.8020756
[38,] -3.61326092 17.1933051 8.8868507
[39,] -3.61326092 17.1933051 8.8868507
[40,] -3.61326092 17.1933051 8.8868507
[41,] -3.61326092 17.1933051 8.8868507
[42,] -1.81956128 17.3817482 9.3416200
[43,] -3.16750568 17.4750150 10.0015226
[44,] -0.79642112 17.6399588 9.8016328
[45,] -1.32371016 18.3821248 9.1565188
[46,] -1.32371016 18.3821248 9.1565188
[47,] -2.34696374 19.0425533 9.5953375
[48,] -2.51424008 18.6796395 10.4786403
[49,] -2.51424008 18.6796395 10.4786403
[50,] -2.51424008 18.6796395 10.4786403
[51,] -3.36106702 20.5262499 11.9090427
[52,] -2.63987051 20.3130838 12.9556715
[53,] -2.63987051 20.3130838 12.9556715
[54,] -0.43401179 19.8467996 13.6708499
[55,] -1.07079158 19.0419739 14.5880249
[56,] -1.07079158 19.0419739 14.5880249
[57,] -1.07079158 19.0419739 14.5880249
[58,] -1.65936235 21.3411496 14.1030373
[59,] -0.98841697 21.9482180 14.3336541
[60,] -0.09823962 23.8173624 14.0384963
[61,] 0.02893740 24.2862548 14.9104612
[62,] 0.02893740 24.2862548 14.9104612
[63,] -0.72302884 24.7495870 15.4289650
[64,] 0.13822989 26.9019761 15.0382800
[65,] 0.13822989 26.9019761 15.0382800
[66,] 0.04357883 26.8745540 16.2482905
[67,] -0.88960758 27.9014297 16.9891905
[68,] -0.68093651 29.0824771 18.7134528
[69,] -0.43058184 29.8667399 18.7786067
[70,] -1.29463115 31.0491849 19.9036095
[71,] -1.24440787 29.5035870 21.8790285
[72,] -1.24440787 29.5035870 21.8790285
[73,] -1.55557884 31.5517166 20.5560774
[74,] -1.99585038 32.3840898 20.3167258
[75,] -1.99585038 32.3840898 20.3167258
[76,] -1.99585038 32.3840898 20.3167258
[77,] -1.89981602 33.6591331 22.0290308
[78,] -1.89981602 33.6591331 22.0290308
[79,] -2.44680063 34.6652981 22.4020355
[80,] -2.44680063 34.6652981 22.4020355
[81,] -1.60876820 35.8310586 22.4224862
[82,] -1.45422863 36.5512324 22.7365438
[83,] -1.18233421 35.5138059 24.0647585
[84,] -0.24570214 37.0486802 24.1860769
[85,] 0.03168349 37.4168439 24.8989192
[86,] 0.03168349 37.4168439 24.8989192
[87,] 0.03168349 37.4168439 24.8989192
[88,] 0.03168349 37.4168439 24.8989192
[89,] 0.24215122 37.5586010 26.5258005
[90,] 0.24215122 37.5586010 26.5258005
[91,] 0.24215122 37.5586010 26.5258005
[92,] 0.58862001 39.3607620 27.9298507
[93,] -0.27926896 40.1595364 29.2239346
[94,] -0.27926896 40.1595364 29.2239346
[95,] -0.27926896 40.1595364 29.2239346
[96,] -0.27926896 40.1595364 29.2239346
[97,] -0.17760335 40.9974928 29.4057818
[98,] 0.43759118 42.1254915 29.5706227
[99,] -0.13618441 42.4222319 29.9347374
[100,] 0.93734292 42.7479366 30.4868951
[101,] 0.93734292 42.7479366 30.4868951
[102,] 0.93734292 42.7479366 30.4868951
Глядя на значение log_posterior_unnormalized
на таких образцах, мы можем видеть, что оно действительно взрывается до бесконечности.
> apply(samples[1:107, ], 1, log_posterior_unnormalized)
[1] -138.6294 -310.8512 -310.8512 -310.8512 -814.2805 -814.2805 -814.2805 -2196.6386 -2196.6386
[10] -2196.6386 -2196.6386 -2969.5265 -3877.5393 -3877.5393 -4299.3816 -4299.3816 -5418.8160 -6098.4219
[19] -6520.9842 -7102.3063 -8222.7219 -9253.5215 -9253.5215 -9933.2233 -10196.3848 -11024.7336 -12248.1868
[28] -12816.9682 -13964.8656 -13964.8656 -14897.0547 -15285.9802 -15566.7016 -15566.7016 -16129.8982 -16129.8982
[37] -16129.8982 -17662.7135 -17662.7135 -17662.7135 -17662.7135 -18160.6234 -18610.1692 -18687.7953 -18765.9209
[46] -18765.9209 -19482.5643 -19800.5161 -19800.5161 -19800.5161 -22021.5857 -22595.2153 -22595.2153 -22838.3542
[55] -22862.6686 -22862.6686 -22862.6686 -24129.0525 -24740.6719 -25905.4325 -26827.0686 -26827.0686 -27470.8678
[64] -28779.8978 -28779.8978 -29565.7594 -30755.0321 -32768.3554 -33383.2801 -34951.9200 -35192.1701 -35192.1701
[73] -35742.4804 -36158.4290 -36158.4290 -36158.4290 -38238.1653 -38238.1653 -39191.6603 -39191.6603 -40079.6649
[82] -40818.7965 -40987.1734 -42216.0336 -42980.1015 -42980.1015 -42980.1015 -42980.1015 -44203.8974 -44203.8974
[91] -44203.8974 -46483.0324 -47911.4547 -47911.4547 -47911.4547 -47911.4547 -48647.6199 -49606.0710 -50046.5853
[100] -Inf -Inf -Inf
Это заставляет меня думать, что, возможно, правило принятия решения не работает должным образом? Может быть, он принимает сэмплы, которые он не должен принимать?
Странный результат: Должен ли я сделать выборку ОТРИЦАТЕЛЬНОЙ записи журнала?
Каким-то образом, если я использую алгоритм Random Walk Metropolis-Hastings для сэмплирования отрицательный логарифм зад Я получаю какие-то разумные результаты? Как это возможно? Вот мой код, чтобы увидеть, что происходит.
start <- matrix(0, nrow=3, ncol=1) # 3 because of the bias
niter <- 100000
samples <- rwmh_log(start=start, niter=niter, logtarget=log_posterior_unnormalized)
Если затем мы используем следующий код построения
samplesdf <- data.frame(samples) %>% mutate(rn=row_number())
trace1 <- ggplot(data=samplesdf, aes(x=rn, y=X1)) + geom_line()
trace2 <- ggplot(data=samplesdf, aes(x=rn, y=X2)) + geom_line()
trace3 <- ggplot(data=samplesdf, aes(x=rn, y=X3)) + geom_line()
grid.arrange(trace1, trace2, trace3, ncol=1)
, мы получим
и с помощью следующего кода построения мы получаем гистограммы для каждой координаты параметра
hist1 <- ggplot(data=samplesdf, aes(x=X1, stat(density))) +
geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
hist2 <- ggplot(data=samplesdf, aes(x=X2, stat(density))) +
geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
hist3 <- ggplot(data=samplesdf, aes(x=X3, stat(density))) +
geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
grid.arrange(hist1, hist2, hist3, ncol=1)
, что дает
Эти результаты можно сделать еще лучше, если начать с оценки карты и использовать обратную приближенную матрицу Гессиана в качестве матрицы дисперсии-ковариации для нормального предложения.
set.seed(123)
# Number of observations, mean and variance-covariance matrix for class 0
n1 <- 100
m1 <- c(6, 6)
s1 <- matrix(c(1, 0, 0, 10), nrow=2, ncol=2)
# Number of observations, mean and variance-covariance matrix for class 1
n2 <- 100
m2 <- c(-1, 1)
s2 <- matrix(c(1, 0, 0, 10), nrow=2, ncol=2)
# Generate explanatory data by sampling bivariate normal distributions
class1 <- mvrnorm(n1, m1, s1)
class2 <- mvrnorm(n2, m2, s2)
# Generate class labels. First n1 are of class 0, last n2 are of class 1
y <- c(rep(0, n1), rep(1, n2))
X <- rbind(class1, class2)
# Add colum of 1s for the bias
X <- cbind(1, X)
# log posterior
log_posterior_unnormalized <- function(beta){
log_prior <- -0.5*sum(beta^2)
log_likelihood <- -sum(log(1 + exp((1 - 2*y) * (X %*% beta))))
return(log_prior + log_likelihood)
}
# define negative log posterior to be minimized by optim function
logtarget <- function(x) -log_posterior_unnormalized(x)
# optimize it to find hessian and starting point
r <- optim(c(0,0,0), logtarget, method="BFGS", hessian=TRUE)
# Start at the MAP estimate
niter <- 100000
start <- r$par
z <- start
pz <- logtarget(z)
samples <- matrix(0, nrow=niter, ncol=3)
# Use the inverse of the approximate hessian matrix for our normal proposal
vcov <- solve(r$hessian)
normal_shift <- mvrnorm(n=niter, mu=c(0,0,0), Sigma=vcov)
samples[1, ] <- start
log_u <- log(runif(niter))
for (i in 2:niter){
# Sample a candidate
candidate <- z + normal_shift[i, ]
# calculate log target of candidate and store it in case it gets accepted
p_candidate <- logtarget(candidate)
# use decision rule explained in blog posts
if (log_u[i] <= pz - p_candidate){
# Accept!
z <- candidate
pz <- p_candidate
}
# Finally add the sample to our matrix of samples
samples[i, ] <- z
}
samplesdf <- data.frame(samples) %>% mutate(rn=row_number())
trace1 <- ggplot(data=samplesdf, aes(x=rn, y=X1)) + geom_line()
trace2 <- ggplot(data=samplesdf, aes(x=rn, y=X2)) + geom_line()
trace3 <- ggplot(data=samplesdf, aes(x=rn, y=X3)) + geom_line()
grid.arrange(trace1, trace2, trace3, ncol=1)
hist1 <- ggplot(data=samplesdf, aes(x=X1, stat(density))) +
geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
hist2 <- ggplot(data=samplesdf, aes(x=X2, stat(density))) +
geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
hist3 <- ggplot(data=samplesdf, aes(x=X3, stat(density))) +
geom_histogram(binwidth=0.05, alpha=0.5, fill="turquoise1", color="turquoise4")
grid.arrange(hist1, hist2, hist3, ncol=1)
, давая следующие два участки