Потеря дискриминатора и потеря генератора не сходятся в фазе регуляризации автоконкодера Adversarial - PullRequest
3 голосов
/ 08 апреля 2019

Я пытаюсь реализовать простой автоконкодер Adversarial для набора данных MNIST с pytorch.Только кодер-декодер (без фазы регуляризации) сходится с уменьшением ошибки, как и ожидалось.Но когда я пытаюсь обучить дискриминатор и кодировщик для фазы регуляризации, потери генератора возрастают, а потери дискриминатора уменьшаются, в отличие от ожидаемых.

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.lin1 = nn.Linear(784, 400)
        self.lin2 = nn.Linear(400,100)
        self.lin3 = nn.Linear(100,2)
    def forward(self, x):
        #Without dropout
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return x
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.lin1 = nn.Linear(2,10)
        self.lin2 = nn.Linear(10,10)
        self.lin3 = nn.Linear(10,2)
    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return torch.sigmoid(x)
encoder = Encoder().cuda()
discriminator = Discriminator().cuda()
encoder_gen_optimizer = optim.Adam(encoder.parameters(), lr = 3e-4)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=3e-4)
for step in range(500):
    scheduler_disc.step()
    scheduler_gen.step()
    for batch_idx, (data, target) in enumerate(data_loader):
        x = Variable(data.view(data.size(0), -1)).cuda()

        #Regularisation loss - train discriminator to detect fake distribution
        discriminator_optimizer.zero_grad()
        encoder_gen_optimizer.zero_grad()
        z_fake = encoder(x)
        D_fake = discriminator(z_fake)
        target_fake = torch.zeros(batchsize, dtype=torch.int64).cuda()
        target_real = torch.ones(batchsize, dtype=torch.int64).cuda()
        z_real = Variable(torch.randn(batchsize,2)).cuda()
        D_real = discriminator(z_real)
        disc_loss = F.cross_entropy(D_real, target_real) + F.cross_entropy(D_fake, target_fake)
        disc_loss.backward()
        discriminator_optimizer.step()

        #Train generator(encoder) to generate normal distribution
        encoder_gen_optimizer.zero_grad()
        discriminator_optimizer.zero_grad()
        z = encoder(x)
        d = discriminator(z)
        t = torch.ones(batchsize, dtype=torch.int64).cuda()
        gen_loss = F.cross_entropy(d, t)
        gen_loss.backward()
        encoder_gen_optimizer.step()
        generator_loss.append(gen_loss) 
        discriminator_loss.append(disc_loss)

Generator_Loss и Discriminator_Loss: enter image description here

...