sampled_indexes = pyro.sample(f"{address}_{index}", pyro.distributions.RelaxedOneHotCategoricalStraightThrough(1, logits=char_dist), obs=observed[index])
Я использую VAE для генерации имен seq2seq, и это не удалось, я полагаю, потому что я выбираю из категориального распределения, которое нельзя поддержать. Поэтому я пытаюсь реализовать Gumbel Softmax. У меня есть этот код, где logits - это категориальное распределение по символам, а index - наблюдаемый символ, который он пытается генерировать в seq2seq. У меня сложилось впечатление, что RelaxedOneHotCategoricalStraightThrough
превратит категориальный в Gumbel Softmax, а затем сделает выборку из него, но он просто сэмплирует категориальное распределение с помощью установки одного символа в 1. Любой знает, как я могу получить это для выборки из Gumbel Softmax. вместо оператора pyro.sample
? PyTorch имеет дистрибутив Gumbel Softmax, который я мог бы использовать, но не работает с pyro.sample. Любой совет?