Вы также можете взглянуть на MixtureSameFamily, которая собирает вас под одеялом.
nn_out1 = tf.expand_dims(nn_out1, axis=2)
...
outs = tf.concat([nn_out1, nn_nout2, ...], axis=2)
probs = tf.tile(tf.reduce_mean(tf.ones_like(nn_out1), axis=1, keepdims=True) / n, [1, n]) # trick to have ones of shape [None,1]
dist = tfp.distributions.MixtureSameFamily(
mixture_distribution=tfp.distributions.Categorical(probs=probs),
components_distribution=tfp.distributions.Deterministic(loc=outs))
x = dist.sample()