Сложность в обучении GAN - PullRequest
1 голос
/ 12 марта 2020

Я пытаюсь обучить GAN, чтобы узнать о распределении ряда функций в событии. Обученные Дискриминатор и Генератор имеют малые потери, но сгенерированные события имеют различное распределение по форме, но я не уверен, почему.

Я определяю GAN следующим образом:

def create_generator():

    generator = Sequential()

    generator.add(Dense(50,input_dim=noise_dim))
    generator.add(LeakyReLU(0.2))    
    generator.add(Dense(25))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(5))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(len(variables), activation='tanh'))

    return generator


def create_descriminator():
    discriminator = Sequential()

    discriminator.add(Dense(4, input_dim=len(variables)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dense(4))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dense(4))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dense(1, activation='sigmoid'))   
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator


discriminator = create_descriminator()
generator = create_generator()

def define_gan(generator, discriminator):
    # make weights in the discriminator not trainable
    discriminator.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    model.compile(loss = 'binary_crossentropy', optimizer=optimizer)
    return model

gan = define_gan(generator, discriminator)

И я обучаю GAN используя это l oop:

for epoch in range(epochs):
    for batch in range(steps_per_epoch):
        noise = np.random.normal(0, 1, size=(batch_size, noise_dim))
        fake_x = generator.predict(noise)

        real_x = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

        x = np.concatenate((real_x, fake_x))
        # Real events have label 1, fake events have label 0
        disc_y = np.zeros(2*batch_size)
        disc_y[:batch_size] = 1

        discriminator.trainable = True
        d_loss = discriminator.train_on_batch(x, disc_y)

        discriminator.trainable = False
        y_gen = np.ones(batch_size)
        g_loss = gan.train_on_batch(noise, y_gen)

Мои реальные события масштабируются с использованием стандартного скейлера sklearn:

scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)

Генерация событий:

X_noise = np.random.normal(0, 1, size=(n_events, GAN_noise_size))
X_generated = generator.predict(X_noise)

Когда я затем используйте обученный GAN после обучения от нескольких сотен до нескольких тысяч эпох, чтобы генерировать новые события и немасштабирование. Я получаю дистрибутивы, которые выглядят следующим образом:

enter image description here

И построение двух функций друг против друга для реальных и фальшивых событий дает: enter image description here

Это похоже на коллапс режима, но я не понимаю, как это может привести к эти экстремальные значения, где все отсекается за этими точками.

1 Ответ

1 голос
/ 27 марта 2020

Свертывание режима приводит к тому, что генератор находит несколько значений или небольшой диапазон значений, которые лучше всего могут обмануть дискриминатор. Поскольку ваш диапазон сгенерированных значений довольно узок, я полагаю, вы испытываете сбой режима. Вы можете тренироваться в течение разных периодов времени и составлять график результатов, чтобы увидеть, когда произойдет коллапс. Иногда, если вы тренируетесь достаточно долго, это исправит себя и начнет учиться снова. Существует миллиард рекомендаций о том, как обучать GAN, я собрал кучу, а затем перебрал их для каждого GAN. Вы можете попробовать тренировать дискриминатор только в каждом следующем цикле, чтобы дать генератору возможность учиться. Кроме того, несколько человек рекомендуют не тренировать дискриминатор на реальных и поддельных данных одновременно (я этого не делал, поэтому не могу сказать, какое влияние это оказывает). Вы также можете попробовать добавить несколько слоев нормализации партии. У Джейсона Браунли есть много хороших статей по обучению GAN, вы можете начать с них.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...