Я работаю над фиктивным примером с генерируемыми сердцебиениями и хочу сначала использовать VAE для кодирования сердцебиений, а затем - простой классификатор.
Проблема в том, что когда я увеличиваю бета выше 0,01, реконструкции становятся ерунда (см. первое изображение). И когда бета низка, я получаю нормальный вывод автоэнкодера без распутывания (второе изображение).
Я полагаю, что проблема может заключаться в моей дивергенции KL или функции потерь VAE, но я не могу ее найти. В моем кодере я делаю репараметризацию следующим образом:
enc = self.encoder(x,batch_size, x_lenghts)
mu = self.enc2mean(enc)
logv = self.enc2logv(enc)
std = torch.exp(0.5*logv)
z = torch.randn([batch_size,1, self.encoder_hidden_sizes[-1] * (int(self.bidirectional)+1)]).to(self.device)
z = z * std + mu
И определяю потери VAE следующим образом:
def VAE_loss(x, reconstruction, mu, logvar, batch_size, latent_dim, beta=0):
mse = F.mse_loss(x, reconstruction)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
KLD /= (batch_size * latent_dim)
return mse + beta*KLD
Полный автономный код для воспроизведения результатов: здесь .
Любые идеи приветствуются!