Как я понимаю, разница между обычным GAN и WGAN заключается в том, что мы обучаем дискриминатор / критики c с большим количеством примеров в каждой эпохе. Если в обычном gan у нас в каждой эпохе по одному пакету для обоих модулей, в WGAN у нас будет 5 пакетов (или более ..) для дискриминатора и один для генератора.
Так что в основном у нас есть другой внутренний oop для дискриминатора:
real_images_labels = np.ones((BATCH_SIZE, 1))
fake_images_labels = -real_images_labels
for epoch in range(epochs):
for batch in range(NUM_BACHES):
for critic_iter in range(n_critic):
random_batches_idx = np.random.randint(0, NUM_BACHES) # Choose random batch from dataset
imgs_data=dataset_list[random_batches_idx]
c_loss_real = critic.train_on_batch(imgs_data, real_images_labels) # update the weights after 1 batch
noise = tf.random.normal([imgs_data.shape[0], noise_dim]) # Generate noise data
generated_images = generator(noise, training=True)
c_loss_fake = critic.train_on_batch(generated_images, fake_images_labels) # update the weights after 1 batch
imgs_data=dataset_list[batch]
noise = tf.random.normal([imgs_data.shape[0], noise_dim]) # Generate noise data
gen_loss_batch = gen_loss_batch + gan.train_on_batch(noise,real_images_labels)
Обучение занимает у меня много времени. В эпоху около 3м. Идея, что мне пришлось сократить время обучения, заключается в том, чтобы вместо этого выполнять каждую серию вперед n_criti c раз. Я могу увеличить размер batch_size для дискриминатора и запустить его один раз вперед с большим размером batch_size.
Что вы, ребята, думаете ? Звучит разумно?
- Я не вставил весь свой код, это была только его часть ..