Я борюсь с формами и индексированием с помощью структур вероятностных данных TensorFlow, особенно в отношении модуля структурных временных рядов (STS). Я построил следующую модель как JointDistributionCoroutine
:
def model_fn():
Root = tfd. JointDistributionCoroutine.Root
venue_effects = yield Root(tfd.Independent(tfp.sts.LocalLinearTrendStateSpaceModel(
num_timesteps=T,
level_scale=tf.ones([N]),
slope_scale=1.,
initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=tf.ones([N, 2])),
name='venue_effects'), 1))
individual_effects = yield Root(tfd.Independent(tfd.Normal(loc=tf.zeros(P),
scale=tf.ones(P),
name="individual_effects"),
reinterpreted_batch_ndims=1))
observation_noise_scale = yield Root(tfd.HalfCauchy(loc=0, scale=1, name='observation_noise_scale'))
observed_effects = yield tfd.Independent(tfd.Normal(
loc=tf.gather(individual_effects, individual_ind)
+ tf.gather_nd(tf.squeeze(venue_effects), np.c_[venue_ind, event_week]),
scale=observation_noise_scale), 1)
model = tfd.JointDistributionCoroutine(model_fn)
Рисование одного образца из этой модели работает нормально и имеет следующие формы образцов:
[TensorShape([30, 32, 1]),
TensorShape([205]),
TensorShape([]),
TensorShape([51788])]
Однако, когда я пытаюсь передать аргумент sample_shape
, чтобы получить несколько выборок, он терпит неудачу:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-25-4fcfe66c9dd1> in <module>()
----> 1 prior_sample = model.sample(10)
13 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: indices[855] = 10 is not in [0, 10) [Op:GatherV2]
К сожалению, я не могу интерпретировать это сообщение об ошибке, за исключением тех случаев, когда оно касается моих вызовов gather
. Однако неясно, почему он работает для одного образца, а не для нескольких образцов. Может ли кто-нибудь пролить свет на это?