Как правильно сохранить модель, чтобы продолжить обучение ВАЭ в керасе - PullRequest
1 голос
/ 04 августа 2020

Я построил VAE в keras, используя функциональный API. VAE имеет 3 модели:

  1. энкодер
def _create_encoder(self):
        """Build the encoder sub-model and return the tensor shape before flattening.

        Side effects: sets self.encoder_mu, self.encoder_log_var,
        self.encoder_parameters, self.encoder_output and self.encoder.
        """
        # Stack the configured Conv2D layers on top of the image input.
        x = self.images
        for idx, n_filters in enumerate(self.encoder_filters):
            x = self._create_conv_layer(x,
                                        "Conv2D",
                                        n_filters,
                                        self.encoder_kernel_size[idx],
                                        self.encoder_strides[idx],
                                        self.encoder_padding[idx],
                                        "encoder_conv" + str(idx)
                                        )
        # Remember the spatial shape at this point; the decoder needs it to
        # reconstruct the same volume after its first Dense layer.
        shape_before_flattening = K.int_shape(x)[1:]
        x = Flatten()(x)
        # Two dense heads produce the Gaussian parameters of the latent space.
        self.encoder_mu = Dense(units=self.latent_space_size, name='encoder_mu')(x)
        self.encoder_log_var = Dense(units=self.latent_space_size, name='encoder_log_var')(x)
        self.encoder_parameters = Model(self.images, (self.encoder_mu, self.encoder_log_var))
        # Reparameterization trick: sample the latent vector from mu / log_var.
        self.encoder_output = Lambda(self.sample_latent_space, name="encoder_output")([self.encoder_mu, self.encoder_log_var])
        self.encoder = Model(inputs=self.images, outputs=self.encoder_output)
        return shape_before_flattening

  2. декодер
 def _create_decoder(self, encoder_shape_before_flattening):
        """Build the decoder sub-model, mirroring the encoder geometry in reverse.

        encoder_shape_before_flattening: spatial shape returned by
        _create_encoder, used to undo the Flatten step.
        Side effects: sets self.decoder_output and self.decoder.
        """
        # Project the latent vector back to the pre-flatten volume and reshape.
        x = Dense(np.prod(encoder_shape_before_flattening))(self.decoder_input)
        x = Reshape(encoder_shape_before_flattening)(x)
        last = len(self.decoder_filters) - 1
        for idx, n_filters in enumerate(self.decoder_filters):
            # The final transpose-conv layer gets no batch-norm, dropout or
            # activation: a plain sigmoid is applied after the loop instead.
            inner = idx != last
            x = self._create_conv_layer(x,
                                        "Conv2DTranspose",
                                        n_filters,
                                        self.decoder_kernel_size[idx],
                                        self.decoder_strides[idx],
                                        self.decoder_padding[idx],
                                        "decoder_conv" + str(idx),
                                        batch_norm=inner,
                                        dropout=inner,
                                        activation=inner
                                        )
        # Pixel intensities are expected in [0, 1], hence the sigmoid output.
        self.decoder_output = Activation("sigmoid")(x)
        self.decoder = Model(inputs=self.decoder_input, outputs=self.decoder_output)
  3. вся модель
def _create_model(self):
        """Wire the encoder and decoder together into the full VAE model.

        Side effects: sets self.images, self.decoder_input and self.model
        (in addition to everything set by the two sub-builders).
        """
        self.images = Input(shape=self.input_dims, name="images")
        # The encoder is built first so we learn the shape the decoder
        # must reconstruct after its Dense/Reshape stage.
        shape_before_flattening = self._create_encoder()
        self.decoder_input = Input(shape=(self.latent_space_size,), name="decoder_input")
        self._create_decoder(shape_before_flattening)
        # End-to-end VAE: images -> sampled latent vector -> reconstruction.
        self.model = Model(inputs=self.images, outputs=self.decoder(self.encoder_output))

Я использую обратный вызов ModelCheckpoint для сохранения всей модели после каждой эпохи.

checkpoint_model = ModelCheckpoint(os.path.join(save_path, "model.h5"), verbose=1)

Но когда я загружаю модель с load_model

def load_trained_model(self, load_path, r_loss_factor):
        """Restore a previously checkpointed VAE so that training can resume.

        load_path: directory containing the saved "model.h5" checkpoint.
        r_loss_factor: reconstruction-loss weight passed to penalized_loss
        so the deserialized model can resolve its custom loss function.
        """
        custom = {
            "loss": self.penalized_loss(r_loss_factor),
            "sample_latent_space": self.sample_latent_space,
        }
        self.model = load_model(os.path.join(load_path, "model.h5"), custom_objects=custom)

и снова вызываю fit_generator, чтобы продолжить обучение, я получаю следующую ошибку:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: You must feed a value for placeholder tensor 'images' with dtype float and shape [?,128,128,3]
     [[{{node images}}]]
     [[metrics_1/loss_1/Identity/_1971]]
  (1) Invalid argument: You must feed a value for placeholder tensor 'images' with dtype float and shape [?,128,128,3]
     [[{{node images}}]]

Код можно найти здесь

...