Правильное получение форм при работе с STS и JointDistributionCoroutine в TensorFlow Probability - PullRequest
0 голосов
/ 18 июня 2020

Я борюсь с формами и индексированием с помощью структур вероятностных данных TensorFlow, особенно в отношении модуля структурных временных рядов (STS). Я построил следующую модель как JointDistributionCoroutine:

def model_fn():
    Root = tfd. JointDistributionCoroutine.Root
    venue_effects = yield Root(tfd.Independent(tfp.sts.LocalLinearTrendStateSpaceModel(
        num_timesteps=T,
        level_scale=tf.ones([N]),
        slope_scale=1.,
        initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=tf.ones([N, 2])),
        name='venue_effects'), 1))
    individual_effects = yield Root(tfd.Independent(tfd.Normal(loc=tf.zeros(P),
                                scale=tf.ones(P),
                                name="individual_effects"),
                             reinterpreted_batch_ndims=1))
    observation_noise_scale = yield Root(tfd.HalfCauchy(loc=0, scale=1, name='observation_noise_scale'))
    observed_effects = yield tfd.Independent(tfd.Normal(
        loc=tf.gather(individual_effects, individual_ind) 
            + tf.gather_nd(tf.squeeze(venue_effects), np.c_[venue_ind, event_week]),
        scale=observation_noise_scale), 1)

model = tfd.JointDistributionCoroutine(model_fn)

Рисование одного образца из этой модели работает нормально и имеет следующие формы образцов:

[TensorShape([30, 32, 1]),
 TensorShape([205]),
 TensorShape([]),
 TensorShape([51788])]

Однако, когда я пытаюсь передать аргумент sample_shape, чтобы получить несколько выборок, он терпит неудачу:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-25-4fcfe66c9dd1> in <module>()
----> 1 prior_sample = model.sample(10)

13 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: indices[855] = 10 is not in [0, 10) [Op:GatherV2]

К сожалению, я не могу интерпретировать это сообщение об ошибке, за исключением тех случаев, когда оно касается моих вызовов gather. Однако неясно, почему он работает для одного образца, а не для нескольких образцов. Может ли кто-нибудь пролить свет на это?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...