Важность взвешенного автоэнкодера работает хуже, чем VAE - PullRequest
0 голосов
/ 01 апреля 2020

Я применяю модели VAE и IWAE для набора данных силуэтов и у меня возникла проблема, при которой VAE превосходит IWAE со скромным запасом (тест LL ~ 120 для VAE, ~ 133 для IWAE! ). Я не верю, что это должно быть так, в соответствии с теорией и проведенными экспериментами здесь .

Я надеюсь, что кто-то может найти какую-то проблему в том, как я реализую, что вызывает это на всякий случай.

Сеть, которую я использую для аппроксимации q и p, такая же, как подробно описана в приложении к статье выше. Расчетная часть модели приведена ниже:

data_k_vec = data.repeat_interleave(K,0) # Generate K samples (in my case K=50 is producing this behavior)

mu, log_std = model.encode(data_k_vec)
z = model.reparameterize(mu, log_std) # z = mu + torch.exp(log_std)*epsilon (epsilon ~ N(0,1))
decoded = model.decode(z) # this is the sigmoid output of the model

log_prior_z = torch.sum(-0.5 * z ** 2, 1)-.5*z.shape[1]*T.log(torch.tensor(2*np.pi))
log_q_z = compute_log_probability_gaussian(z, mu, log_std) # Definitions below
log_p_x = compute_log_probability_bernoulli(decoded,data_k_vec) 

if model_type == 'iwae':
    log_w_matrix = (log_prior_z + log_p_x  - log_q_z).view(-1, K)
elif model_type =='vae':
    log_w_matrix = (log_prior_z + log_p_x  - log_q_z).view(-1, 1)*1/K

log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]
ws_matrix = torch.exp(log_w_minus_max)
ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

loss = -torch.sum(ws_sum_per_datapoint) # value of loss that gets returned to training function. loss.backward() will get called on this value

Вот функции вероятности. Мне пришлось суетиться с Bernoulli LL, чтобы не получать nan во время обучения

def compute_log_probability_gaussian(obs, mu, logstd, axis=1):
    return torch.sum(-0.5 * ((obs-mu) / torch.exp(logstd)) ** 2 - logstd, axis)-.5*obs.shape[1]*T.log(torch.tensor(2*np.pi)) 

def compute_log_probability_bernoulli(theta, obs, axis=1): # Add 1e-18 to avoid nan appearances in training
    return torch.sum(obs*torch.log(theta+1e-18) + (1-obs)*torch.log(1-theta+1e-18), axis)

В этом коде используется «быстрая комбинация», в которой весовые значения важности строки вычисляются в model_type=='iwae' в случае для K = 50 выборок в каждой строке, в то время как в случае model_type=='vae' веса важности рассчитываются для единственного значения, оставленного в каждой строке, так что в итоге вычисляется только вес 1. Может быть, это проблема?

Любая и вся помощь огромна - я думал, что решение этой проблемы навсегда избавит меня от сорняков, но теперь у меня есть эта новая проблема.

РЕДАКТИРОВАТЬ: Следует добавить, что Схема обучения такая же, как в статье, приведенной выше. То есть для каждого из i=0....7 раундов поезд для 2**i эпох с темпом обучения 1e-4 * 10**(-i/7)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...