Реконструкция потерь по типу регрессии Variational Autoencoder - PullRequest
1 голос
/ 12 апреля 2019

В настоящее время я работаю над вариантом Variational Autoencoder в последовательной настройке, где задача состоит в том, чтобы подогнать / восстановить последовательность реальных данных наблюдений (следовательно, это проблема регрессии).

Я построил свою модель, используя tf.keras с активным исполнением и tenorflow_probability (tfp).Следуя концепции VAE, порождающая сеть испускает параметры распределения данных наблюдений, которые я моделирую как многомерную нормаль.Поэтому выходные данные являются средними значениями и logvar прогнозируемого распределения.

Что касается тренировочного процесса, первым компонентом потери является ошибка реконструкции.Это логарифмическая вероятность истинного наблюдения, учитывая предсказанное (параметры) распределение из генеративной сети.Здесь я использую tfp.distributions, так как это быстро и удобно.

Однако после завершения обучения, отмеченного значительно меньшей величиной потерь, оказывается, что моя модель, похоже, ничего не изучает.Предсказанное значение из модели является едва ли плоским во временном измерении (напомним, что проблема является последовательной).

Тем не менее, для проверки работоспособности, когда я заменяю логарифмическую вероятность на потерю MSE (что неоправданно при работе с VAE), это дает очень хорошее согласование данных.Итак, я пришел к выводу, что с этим термином логарифмического правдоподобия должно быть что-то не так.Есть ли у кого-нибудь какая-то подсказка и / или решение для этого?

Я рассмотрел вопрос о замене вероятности логарифма на потерю кросс-энтропии, но я думаю, что это не применимо в моем случае, так как моя проблема - регрессия иданные не могут быть нормализованы в диапазоне [0,1].

Я также пытался реализовать отожженный член KL (т. е. взвешивающий член KL с константой <1) при использовании логарифмической вероятности в качестве потерь при восстановлении.Но это также не сработало. </p>

Вот мой фрагмент кода функции потери (с использованием вероятности журнала в качестве ошибки восстановления):

    import tensorflow as tf
    tfe = tf.contrib.eager
    tf.enable_eager_execution()

    import tensorflow_probability as tfp
    tfd = tfp.distributions

    def loss(model, inputs):
        outputs, _ = SSM_model(model, inputs)

        #allocate the corresponding output component
        infer_mean = outputs[:,:,:latent_dim]  #mean of latent variable from  inference net
        infer_logvar = outputs[:,:,latent_dim : (2 * latent_dim)]
        trans_mean = outputs[:,:,(2 * latent_dim):(3 * latent_dim)] #mean of latent variable from transition net
        trans_logvar = outputs[:,:, (3 * latent_dim):(4 * latent_dim)]
        obs_mean = outputs[:,:,(4 * latent_dim):((4 * latent_dim) + output_obs_dim)] #mean of observation from  generative net
        obs_logvar = outputs[:,:,((4 * latent_dim) + output_obs_dim):]
        target = inputs[:,:,2:4]

        #transform logvar to std
        infer_std = tf.sqrt(tf.exp(infer_logvar))
        trans_std = tf.sqrt(tf.exp(trans_logvar))
        obs_std = tf.sqrt(tf.exp(obs_logvar))

        #computing loss at each time step
        time_step_loss = []
        for i in range(tf.shape(outputs)[0].numpy()):
            #distribution of each module
            infer_dist = tfd.MultivariateNormalDiag(infer_mean[i],infer_std[i])
            trans_dist = tfd.MultivariateNormalDiag(trans_mean[i],trans_std[i])
            obs_dist = tfd.MultivariateNormalDiag(obs_mean[i],obs_std[i])

            #log likelihood of observation
            likelihood = obs_dist.prob(target[i]) #shape = 1D = batch_size
            likelihood = tf.clip_by_value(likelihood, 1e-37, 1)
            log_likelihood = tf.log(likelihood)

            #KL of (q|p)
            kl = tfd.kl_divergence(infer_dist, trans_dist) #shape = batch_size

            #the loss
            loss = - log_likelihood + kl
            time_step_loss.append(loss)

        time_step_loss = tf.convert_to_tensor(time_step_loss)        
        overall_loss = tf.reduce_sum(time_step_loss)
        overall_loss = tf.cast(overall_loss, dtype='float32')

        return overall_loss
...