Замораживание дискриминаторных слоев в GAN - PullRequest
0 голосов
/ 02 февраля 2020

Я пытаюсь реализовать SimGAN в Керасе, но я считаю, что вопрос в более общем плане связан с GAN и замораживающими слоями.

Насколько я понимаю, мне нужны три модели:

  1. Улучшитель , который обрабатывает синтетическое c изображение, чтобы сделать его более реалистичным c. (В других архитектурах GAN это может быть генератор , который принимает случайный шум.)

  2. дискриминатор , который обрабатывает изображение и классифицирует его как syntheti c или вещественное.

  3. Комбинированная модель, которая передает синтезированное c изображение через рафинер, а затем очищенное изображение в дискриминатор.

В объединенной модели мы хотим, чтобы слои дискриминатора были заморожены , чтобы при обучении объединенной модели мы обновляли только слои уточнения.

Отдельно мы обучаем модель только для дискриминатора, где очевидно, что слои не должны быть заморожены. Слои модели дискриминатора и объединенной модели должны иметь общие веса, чтобы оба они обновлялись.

Вот что у меня получилось:

refiner_model = make_refiner_model(input_shape=(img_height, img_width, img_channels))
discriminator_model = make_discriminator_model(input_shape=refiner_model.output_shape[1:])

# create combined model with frozen discriminator layers
synthetic_image_tensor = layers.Input(refiner_model.input_shape[1:])
refiner_model_output = refiner_model(synthetic_image_tensor)
combined_output = discriminator_model(refiner_model_output)

combined_model = models.Model(
    inputs=synthetic_image_tensor,
    outputs=[refiner_model_output, combined_output],
    name='combined'
)

Как заморозить слои дискриминатора в комбинированная модель без замораживания их в модели только для дискриминатора?


В FAQ по Keras прямо предлагается следующее:

refiner_model.compile(...)
discriminator_model.compile(...)

discriminator_model.trainable = False
combined_model.compile(...)

Но потом, когда я распечатаю discriminator_model.summary(), число параметров удвоилось?

Total params: 151,812
Trainable params: 75,906
Non-trainable params: 75,906

Тогда я получу предупреждений об изменении .trainable без перекомпиляции , и в итоге произойдет сбой с эта ошибка .

...