Чтобы обучить генератор, вы должны продвинуться по всей объединенной модели, замораживая веса дискриминатора, так что обновляется только генератор.
Для этого нам нужно вычислить d(g(z; θg); θd)
,где θg и θd - веса генератора и дискриминатора.Чтобы обновить генератор, мы можем вычислить градиент относительно.до θg только ∂loss(d(g(z; θg); θd)) / ∂θg
, а затем обновите θg с использованием нормального градиентного спуска.
В Keras это может выглядеть примерно так (при использовании функционального API):
genInput = Input(input_shape)
discriminator = ...
generator = ...
discriminator.trainable = True
discriminator.compile(...)
discriminator.trainable = False
combined = Model(genInput, discriminator(generator(genInput)))
combined.compile(...)
При установке значения trainable
в False уже скомпилированные модели не затрагиваются, только модели, скомпилированные вбудущее заморожено.Таким образом, дискриминатор может быть обучен как отдельная модель, но заморожен в комбинированной модели.
Затем, чтобы обучить вашу GAN:
X_real = ...
noise = ...
X_gen = generator.predict(noise)
# This will only train the discriminator
loss_real = discriminator.train_on_batch(X_real, one_out)
loss_fake = discriminator.train_on_batch(X_gen, zero_out)
d_loss = 0.5 * np.add(loss_real, loss_fake)
noise = ...
# This will only train the generator.
g_loss = self.combined.train_on_batch(noise, one_out)