В учебном пособии по VAE kl-расхождение двух нормальных распределений определяется как:
И во многих кодах, таких как здесь , здесь и здесь , код реализован как:
KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
или
def latent_loss(z_mean, z_stddev):
mean_sq = z_mean * z_mean
stddev_sq = z_stddev * z_stddev
return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
Как они связаны? почему в коде нет "tr" или ".transpose ()"?