Keras множественный вход, выход, модель потерь - PullRequest
0 голосов
/ 01 октября 2019

Я работаю над суперразрешением GAN и у меня есть некоторые сомнения по поводу кода, который я нашел на Github. В частности, у меня есть несколько входов, несколько выходов в модели. Кроме того, у меня есть две разные функции потерь.

В следующем коде потери mse будут применены к img_hr и fake_features?

# Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # High res. and low res. images
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # Generate high res. version from low res.
        fake_hr = self.generator(img_lr)

        # Extract image features of the generated img
        fake_features = self.vgg(fake_hr)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

# Discriminator determines validity of generated high res. images
        validity = self.discriminator(fake_hr)

        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)

Ответы [ 2 ]

0 голосов
/ 02 октября 2019

В нейронных сетях Loss применяется к выходам сети, чтобы иметь способ измерения «Насколько неправильный этот выход?»так что вы можете принять это значение и минимизировать его с помощью градиента приличного и обратного. Следуя этой интуиции, потери в кератах - это список той же длины, что и выходы вашей модели. Они применяются к выходу с тем же индексом.

self.combined = Model([img_lr, img_hr], [validity, fake_features])

Это дает вам модель с 2 входами (img_lr, img_hr) и 2 выходами (validity, fake_features). Таким образом, combined.compile(loss=['binary_crossentropy', 'mse']... использует потерю binary_crossentropy для достоверности и Mean Squared Error для fake_features.

0 голосов
/ 02 октября 2019

В следующем коде будет применена потеря mse к img_hr и fake_features?

Из документации https://keras.io/models/model/#compile

" Если модельимеет несколько выходов, вы можете использовать разные потери на каждом выходе, передав словарь или список потерь."

В этом случае потеря mse будет применена к fake_features и соответствующему y_true переданномукак часть self.combined.fit().

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...