Бета-вариационный AutoEncoder не может распутать - PullRequest
2 голосов
/ 05 марта 2020

Я работаю над фиктивным примером с генерируемыми сердцебиениями и хочу сначала использовать VAE для кодирования сердцебиений, а затем - простой классификатор.

Проблема в том, что когда я увеличиваю бета выше 0,01, реконструкции становятся ерунда (см. первое изображение). И когда бета низка, я получаю нормальный вывод автоэнкодера без распутывания (второе изображение). Beta=0.1Beta=0.01

Я полагаю, что проблема может заключаться в моей дивергенции KL или функции потерь VAE, но я не могу ее найти. В моем кодере я делаю репараметризацию следующим образом:

enc = self.encoder(x,batch_size, x_lenghts)
mu = self.enc2mean(enc)
logv = self.enc2logv(enc)
std = torch.exp(0.5*logv)
z = torch.randn([batch_size,1, self.encoder_hidden_sizes[-1] * (int(self.bidirectional)+1)]).to(self.device)
z = z * std + mu

И определяю потери VAE следующим образом:

def VAE_loss(x, reconstruction, mu, logvar, batch_size, latent_dim, beta=0):
    mse = F.mse_loss(x, reconstruction)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    KLD /= (batch_size * latent_dim)
    return mse + beta*KLD

Полный автономный код для воспроизведения результатов: здесь .

Любые идеи приветствуются!

...