Я хотел бы вычислить расхождение KL между 2 распределениями RelaxedOneHotCategorial в PyTorch. У меня такое чувство, что я делаю что-то не так, поскольку дивергенция KL очень высока.
import torch
from torch.distributions import RelaxedOneHotCategorical
p_m = RelaxedOneHotCategorical(torch.tensor([2.2]), probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
batch_param_obtained_from_a_nn = torch.rand(2, 4)
q_m = RelaxedOneHotCategorical(torch.tensor([5.4]), logits=batch_param_obtained_from_a_nn)
z = q_m.rsample()
kl = - torch.mean(q_m.log_prob(z).exp() * (q_m.log_prob(z) - p_m.log_prob(z)))
z
tensor([[0.2671, 0.2973, 0.2144, 0.2212],
[0.2431, 0.2550, 0.3064, 0.1954]])
kl
tensor(-766.7020)
Я пропустил что-то тривиальное? Должен ли я сделать что-то особенное из https://arxiv.org/pdf/1611.00712.pdf? Я вижу, что RelaxedOneHotCategorical основан на ExpConcrete и должен обрабатывать потери значения.