NotImplementedError: Невозможно преобразовать ограничения _Boolean Невозможно преобразовать ограничения _Boolean при запуске svi.step () в pyro (pyro-ppl) - PullRequest
0 голосов
/ 20 февраля 2019

Я пытаюсь реализовать байесовскую сеть в пиро (pyro-ppl).Вот мой код модели:

import torch,pyro
from pyro.infer import SVI, Trace_ELBO
from torch.distributions.constraints import unit_interval,boolean,interval
import pyro.contrib.autoguide as ag
import pyro.optim as opt
from pyro.distributions import Bernoulli

def x_model():
    pT = pyro.param("pT", torch.tensor(.1),        constraint=unit_interval)
    pF = pyro.param("pF", torch.tensor(.0001),     constraint=unit_interval)
    pS = pyro.param("pS", torch.tensor([.1,.9]),   constraint=unit_interval)
    pA = pyro.param("pA", torch.tensor([[.0001,.85],[.99,.5]]),constraint=unit_interval)
    pL = pyro.param("pL", torch.tensor([.001,.88]),constraint=unit_interval)
    pR = pyro.param("pR", torch.tensor([.01,.75]), constraint=unit_interval)

    T  = pyro.sample("T", Bernoulli(0.1))
    F  = pyro.sample("F", Bernoulli(.0001))
    S  = pyro.sample("S", 1) #S=TRUE
    A  = pyro.sample("A", Bernoulli(pA[T.long(),F.long()]))
    L  = pyro.sample("L", Bernoulli(pL[A.long()]))
    R  = pyro.sample("R", 1) #R=TRUE

# Inference for P(A|S,R)
svi = SVI(p_model,                              \
          guide=ag.AutoDiagonalNormal(p_model), \
          optim=opt.Adam({'lr': .001}),     \
          loss =Trace_ELBO())

for step in range(2000):
    svi.step()
    print('.', end='')

Во время фазы вывода я получаю эту ошибку:

KeyError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/torch/distributions/constraint_registry.py in __call__(self,constraint)
    138         try:
--> 139             factory = self._registry[type(constraint)]
    140         except KeyError:

KeyError: <class 'torch.distributions.constraints._Boolean'>

Я пытался изменить мои ограничения в параметрах на интервал или полностью удалить ограничение на нетпомогло.Это выполняется в блокноте Google Colab.

...