Dirichlet-Multinomial не обновляется правильно в Pyro? - PullRequest
0 голосов
/ 26 января 2020

Я новичок в вероятностном c программировании и пробовал несколько игрушечных примеров с Pyro. У меня возникли проблемы с реализацией SVI для Dirichlet-Multinomial, так как алгоритм, похоже, не сходится ни к чему разумному. Вот что у меня есть:


data = [10.0,12.0,2.0]

def model(data = data):
    d = torch.FloatTensor(data)
    return pyro.sample("obs", dist.DirichletMultinomial(torch.ones(len(data)), sum(data)), obs = d)

def guide(data=data):
    params = pyro.param("alphas", torch.ones(len(data)),constraint=constraints.positive)
    return pyro.sample("obs", dist.DirichletMultinomial(params,sum(data)))

adam_params = {"lr": 0.01}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
alphas, losses = [], []
n_steps = 5000
for step in range(n_steps):
  alphas.append(pyro.param("alphas"))
  losses.append(svi.step(data))

В зависимости от скорости обучения параметры сходятся либо вовсе, либо к чему-то совершенно неправильному (например, более высокие значения концентрации для более редких событий). Любой совет будет очень признателен, спасибо !!

...