Как KL-дивергенция в коде pytorch связана с формулой? - PullRequest
2 голосов
/ 04 мая 2020

В учебном пособии по VAE kl-расхождение двух нормальных распределений определяется как: enter image description here

И во многих кодах, таких как здесь , здесь и здесь , код реализован как:

 KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

или

def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

Как они связаны? почему в коде нет "tr" или ".transpose ()"?

1 Ответ

4 голосов
/ 04 мая 2020

Выражения в размещенном вами коде предполагают, что X является некоррелированной многовариантной гауссовской случайной величиной. Это очевидно по отсутствию перекрестных членов в детерминанте ковариационной матрицы. Поэтому средний вектор и ковариационная матрица имеют вид

enter image description here

Используя это, мы можем быстро получить следующие эквивалентные представления для компонентов исходного выражения

enter image description here

Подстановка их обратно в исходное выражение дает

enter image description here

...