Как вычислить KL-Divergence RelaxedOneHotКатегория - PullRequest
1 голос
/ 03 апреля 2019

Я хотел бы вычислить расхождение KL между 2 распределениями RelaxedOneHotCategorial в PyTorch. У меня такое чувство, что я делаю что-то не так, поскольку дивергенция KL очень высока.

import torch
from torch.distributions import RelaxedOneHotCategorical


p_m = RelaxedOneHotCategorical(torch.tensor([2.2]), probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
batch_param_obtained_from_a_nn = torch.rand(2, 4)
q_m = RelaxedOneHotCategorical(torch.tensor([5.4]), logits=batch_param_obtained_from_a_nn)

z = q_m.rsample()

kl = - torch.mean(q_m.log_prob(z).exp() * (q_m.log_prob(z) - p_m.log_prob(z)))

z                                                                                                                                                                                                                                                                       
tensor([[0.2671, 0.2973, 0.2144, 0.2212],
        [0.2431, 0.2550, 0.3064, 0.1954]])

kl                                                                                                                                                                                                                                                                      
tensor(-766.7020)

Я пропустил что-то тривиальное? Должен ли я сделать что-то особенное из https://arxiv.org/pdf/1611.00712.pdf? Я вижу, что RelaxedOneHotCategorical основан на ExpConcrete и должен обрабатывать потери значения.

...