Я тренировал условную архитектуру GAN, аналогичную Pix2Pix, со следующим циклом обучения:
for epoch in range(start_epoch, end_epoch):
for batch_i, (input_batch, target_batch) in enumerate(dataLoader.load_batch(batch_size)):
fake_batch= self.generator.predict(input_batch)
d_loss_real = self.discriminator.train_on_batch(target_batch, valid)
d_loss_fake = self.discriminator.train_on_batch(fake_batch, invalid)
d_loss = np.add(d_loss_fake, d_loss_real) * 0.5
g_loss = self.combined.train_on_batch([target_batch, input_batch], [valid, target_batch])
Теперь это работает хорошо, но не очень эффективно, поскольку загрузчик данных быстро становится узким местом. мудрый. Я изучил функцию .fit_generator (), которую предоставляет keras, которая позволяет генератору работать в рабочем потоке и работает намного быстрее.
self.combined.fit_generator(generator=trainLoader,
validation_data=evalLoader
callbacks=[checkpointCallback, historyCallback],
workers=1,
use_multiprocessing=True)
Мне потребовалось некоторое время, чтобы понять, что это неправильно, я больше не тренировал свой генератор и дискриминатор отдельно, а дискриминатор не тренировался вообще, так как он установлен на trainable = False
в комбинированной модели, по сути, разрушающей любой вид состязательной потери, и я мог бы также обучить свой генератор самостоятельно с помощью MSE
.
Теперь мой вопрос заключается в том, есть ли какая-то работа вокруг, такая как тренировка моего дискриминаторав пользовательском обратном вызове, который запускается в каждом пакете метода .fit_generator ()? Можно реализовать создание пользовательских обратных вызовов, например, вот так:
class MyCustomCallback(tf.keras.callbacks.Callback):
def on_train_batch_end(self, batch, logs=None):
discriminator.train_on_batch()
Другая возможность - распараллелить исходный цикл обучения, но я боюсь, что у меня нет времени, чтобы сделать это прямо сейчас.