График GAN не сходится к хорошему результату - PullRequest
0 голосов
/ 03 апреля 2020

Я сейчас создаю GAN для генерации новых графиков с использованием PyTorch. Генератор создает новые матрицы объектов, эти выходные данные объединяются с установленной матрицей смежности, а затем передаются в нейронную сеть Graph для классификации, являются ли ее фальшивые данные или нет. Хотя этот элемент сетей графа присутствует, в целом он должен быть аналогичен использованию генератора и классификатора «как обычно».

Хотя я считаю, что эта настройка может работать, мои текущие результаты выглядят не очень хорошо , Я получаю очень низкие потери как для генератора, так и для дискриминатора, если он действительно сходится к чему-то, то это ситуация, когда дискриминатор имеет очень низкие потери, а генератор - очень большие потери. Я подозреваю, что что-то в моем обучении l oop не так.

Мои вопросы:

  • Последний слой моего генератора имеет 25 узлов. Входные данные для моего классификатора ожидают чего-то такого (5 * 5). Чтобы учесть это, я изменил свои данные так: fakeData = nn.Parameter(fakeData.view(len(adjacency), 5, 5)). Меня беспокоит то, что это каким-то образом отключит его от остальной части графика, а это означает, что я не могу вернуться в генератор. Это тот случай? Как мне это преодолеть?
  • Как все части включаются и выключаются, мне кажется странным. Я включил disc.train() в начале, чтобы убедиться, что все готовится. Мне это действительно нужно?

Я приложил код тренинга l oop и части, где обучаются генератор и дискриминатор. Я пытаюсь следовать чему-то очень похожему на: https://medium.com/ai-society/gans-from-scratch-1-a-deep-introduction-with-code-in-pytorch-and-tensorflow-cb03cdcdba0f

Если у кого-то есть какие-либо мысли по поводу вопросов или кода, или у него был опыт создания новых графиков и есть какие-либо заметки, я бы быть очень благодарным за это.

def train(args, disc, d_opt, gen, g_opt, trainData):

    # I grab my 'real' data from my dataloader:
    train_loader = DataLoader(trainData, batch_size = args.batchSize, shuffle=True)
    criterion = nn.BCELoss()
    for epoch in range(args.epochs):
        print("Epoch: ", epoch)

        discError = 0
        genError = 0

        # set my generator and discriminator to trainable:
        disc.train()
        gen.train()
        for batch in tqdm(train_loader):
            adjacency, features, tar, _ = batch

            # get targets for my 'fake' data:
            ones = [[1] for i in range(len(adjacency))]
            targetFake = torch.tensor(ones,dtype=torch.float)

            #generate some fake data from my generator
            fake_data = noise(len(adjacency))
            fakeData = gen(fake_data).detach()
            # last layer of generator is linear layer of 25 nodes,
            # graph classifier network needs it in shape of (5,5):
            fakeData = nn.Parameter(fakeData.view(len(adjacency), 5, 5)).detach()

            # train disc
            error_real, error_fake = \
                        train_discriminator(disc, d_opt, criterion, features, adjacency, tar, fakeData, targetFake)

            discError += (error_real + error_fake)

            # generate more data, do not detach!
            fake_data = noise(len(adjacency))
            fakeData = gen(fake_data)
            fakeData = nn.Parameter(fakeData.view(len(adjacency), 5, 5))

            # train gen
            error, acc = train_generator(disc, g_opt, criterion, fakeData, adjacency, tar)

            genError += error

        discError /= len(trainData)*2
        genError /= len(trainData)*2


def train_discriminator(disc, opt, loss, real_data, adj, targets_real, fake_data, fake_tar):
    opt.zero_grad()

    # prediction from real data
    prediction_real = disc(real_data, adj)
    error_real = loss(prediction_real, targets_real)

    # prediction from fake data
    prediction_fake = disc(fake_data, adj)
    error_fake = loss(prediction_fake, fake_tar)

    #calculate total error and backprop
    error = error_real+error_fake
    error.backward()

    opt.step()

    return error_real, error_fake


def train_generator(disc, opt, loss, fake_data, adj, fake_tar):
    opt.zero_grad()

    prediction_fake = disc(fake_data, adj)
    error_fake = loss(prediction_fake, fake_tar)
    error_fake.backward()

    opt.step()

    return error_fake

...