Значение образца пиро из распределения Gumbel Softmax - PullRequest
0 голосов
/ 31 января 2020
sampled_indexes = pyro.sample(f"{address}_{index}", pyro.distributions.RelaxedOneHotCategoricalStraightThrough(1, logits=char_dist), obs=observed[index])

Я использую VAE для генерации имен seq2seq, и это не удалось, я полагаю, потому что я выбираю из категориального распределения, которое нельзя поддержать. Поэтому я пытаюсь реализовать Gumbel Softmax. У меня есть этот код, где logits - это категориальное распределение по символам, а index - наблюдаемый символ, который он пытается генерировать в seq2seq. У меня сложилось впечатление, что RelaxedOneHotCategoricalStraightThrough превратит категориальный в Gumbel Softmax, а затем сделает выборку из него, но он просто сэмплирует категориальное распределение с помощью установки одного символа в 1. Любой знает, как я могу получить это для выборки из Gumbel Softmax. вместо оператора pyro.sample? PyTorch имеет дистрибутив Gumbel Softmax, который я мог бы использовать, но не работает с pyro.sample. Любой совет?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...