Как реализовать гауссову смесь для VAE? - PullRequest
1 голос
/ 24 марта 2019

Мне кажется, что я действительно не знаю, что я делаю, поэтому я опишу, что, по моему мнению, я делаю, и что я хочу сделать, и где это терпит неудачу.

Учитывая обычный вариационный автоэнкодер:

...
net = tf.layers.dense(net, units=code_size * 2, activation=None)
mean = net[:, :code_size]
std = net[:, code_size:]
posterior = tfd.MultivariateNormalDiagWithSoftplusScale(mean, std)
net = posterior.sample()
net = tf.layers.dense(net, units=input_size, ...)
...

То, что я думаю, я делаю: пусть нейронная сеть найдет значение «среднее» и «стандартное отклонение» и использует его для создания нормального распределения (гауссовского).Пример из этого распределения и использовать его для декодера.Другими словами: изучите гауссово распределение кодировки

Теперь я хотел бы сделать то же самое для смеси гауссиан.

...
net = tf.layers.dense(net, units=code_size * 2 * code_size, activation=None)

means, stds = tf.split(net, 2, axis=-1)

means = tf.split(means, code_size, axis=-1)
stds = tf.split(stds, code_size, axis=-1)

components = [tfd.MultivariateNormalDiagWithSoftplusScale(means[i], stds[i]) for i in range(code_size)]
probs = [1.0 / code_size] * code_size

gauss_mix = tfd.Mixture(cat=tfd.Categorical(probs=probs), components=components)
net = gauss_mix.sample()
net = tf.layers.dense(net, units=input_size, ...)
...

Это казалось относительно простым для меня, за исключением того, что онозавершается со следующей ошибкой:

Shapes () и (?,) несовместимы

Это, кажется, происходит от probs, который не имеет пакетаизмерение (я не думал, что это понадобится).

Я думал, что probs определяет вероятность между компонентами.

Если я определю probs, в котором также есть партияЯ получаю следующую зашифрованную ошибку. Я не знаю, что это должно означать:

Размер -1796453376 должен быть> = 0

Обычно я неправильно понимаюнекоторые понятия?

или что мне нужно сделать по-другому?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...