Постепенно угасать вес функции потери - PullRequest
0 голосов
/ 05 января 2019

Я не уверен, что это правильное место, чтобы задать этот вопрос, не стесняйтесь сказать мне, если мне нужно удалить сообщение.

Я довольно новичок в pyTorch и в настоящее время работаю с CycleGAN (реализация pyTorch) как часть моего проекта, и я понимаю большую часть реализации CycleGAN.

Я прочитал газету под названием «CycleGAN с лучшими циклами» и пытаюсь применить модификацию, упомянутую в статье. Одна из модификаций - это снижение веса согласованности циклов, которое я не знаю, как применить.

optimizer_G.zero_grad()

# Identity loss
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)

loss_identity = (loss_id_A + loss_id_B) / 2

# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

# Cycle consistency loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)

loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

# Total loss
loss_G =    loss_GAN + 
            lambda_cyc * loss_cycle + #lambda_cyc is 10
            lambda_id * loss_identity #lambda_id is 0.5 * lambda_cyc

loss_G.backward()
optimizer_G.step()

Мой вопрос: как я могу постепенно уменьшить вес потери согласованности цикла?

Любая помощь в реализации этой модификации будет принята.

Это из бумаги: Потеря согласованности цикла помогает стабилизировать тренировку на ранних стадиях, но становится препятствием для реалистичных изображений на более поздних стадиях. Мы предлагаем постепенно уменьшать вес потери согласованности цикла λ по мере прогресса обучения . Тем не менее, мы должны убедиться, что λ не уменьшается до 0, чтобы генераторы не становились неограниченными и полностью выходили из строя.

Заранее спасибо.

1 Ответ

0 голосов
/ 05 января 2019

Ниже приведена функция-прототип, которую вы можете использовать!

def loss (other params, decay params, initial_lambda, steps):
    # compute loss
    # compute cyclic loss
    # function that computes lambda given the steps
    cur_lambda  = compute_lambda(step, decay_params, initial_lamdba) 

    final_loss = loss + cur_lambda*cyclic_loss 
    return final_loss

compute_lambda функция для линейного затухания от 10 до 1e-5 с 50 шагами

def compute_lambda(step, decay_params):
    final_lambda = decay_params["final"]
    initial_lambda = decay_params["initial"]
    total_step = decay_params["total_step"]
    start_step = decay_params["start_step"]

    if (step < start_step+total_step and step>start_step):
        return initial_lambda + (step-start_step)*(final_lambda-initial_lambda)/total_step
    elif (step < start_step):
        return initial_lambda 
    else:
        return final_lambda
# Usage:
compute_lambda(i, {"final": 1e-5, "initial":10, "total_step":50, "start_step" : 50})    
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...