Как мне обучить мою генераторную сеть больше, чем мой дискриминатор? - PullRequest
0 голосов
/ 19 марта 2020

Я пытаюсь ознакомиться с GAN и нашел пример использования базы данных mnist. Обучение выполняется в коде l oop, но я хочу тренировать дискриминатор меньше, чем генератор. Однако я не могу заставить это работать. Я попытался сделать отдельные циклы для дискриминатора и генератора, но это не работает. Кто-нибудь, имеющий опыт работы с GAN, может рассказать о том, как это делается? Полный код можно увидеть на https://medium.com/analytics-vidhya/implementing-a-gan-in-keras-d6c36bc6ab5f.

для эпох в диапазоне (эпох):

for batch in range(steps_per_epoch):
    noise = np.random.normal(0, 1, size=(batch_size, noise_dim))
    fake_x = generator.predict(noise)

    real_x = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

    x = np.concatenate((real_x, fake_x))

    disc_y = np.zeros(2*batch_size)
    disc_y[:batch_size] = 0.9

    d_loss = discriminator.train_on_batch(x, disc_y)

    y_gen = np.ones(batch_size)
    g_loss = gan.train_on_batch(noise, y_gen)

print(f'Epoch: {epoch} \t Discriminator Loss: {d_loss} \t\t Generator Loss: {g_loss}')
...