Керас / тензор потока VAE странное поведение с функцией потерь - PullRequest
0 голосов
/ 23 февраля 2019

Я использую Keras для моделирования VAE.Моя полная записная книжка доступна в https://gist.github.com/v-i-s-h/fdabcb3d85b89ade95758bd014b307f1

Код для обучения модели с помощью только регистрации окончательных значений потерь (сумма потерь при реконструкции и потерь в KL), у меня есть следующий блок

vae = Model( inputs, outputs, name = "vae" )

# VAE loss = mse_loss OR kl_loss+xent_loss
# reconLoss = mse( K.flatten(inputs), K.flatten(outputs) )
reconLoss = binary_crossentropy( K.flatten(inputs), K.flatten(outputs) )
reconLoss *= imageSize*imageSize # Because binary_crossentropy divides by N
klLoss = 1 + zLogVar - K.square(zMean) - K.exp(zLogVar)
klLoss = K.sum( klLoss, axis = -1 )
klLoss *= -0.5

vaeLoss = K.mean( reconLoss + klLoss )
vae.add_loss( vaeLoss )
vae.compile( optimizer = 'adam' )
vae.summary()
plot_model( vae, to_file = 'vae_model.png', show_shapes = True )


# Callback
import tensorflow.keras.callbacks as cb
class PlotResults( cb.Callback ):
  def __init__( self, models, data, batch_size, model_name ):
    self.models = models
    self.data = data
    self.batchSize = batch_size
    self.model_name = model_name
    self.epochCount = 0
    super().__init__()
  def on_train_begin( self, log = {} ):
    self.epochCount = 0
    plot_results( models, data, batch_size = self.batchSize, epochCount = self.epochCount )

  def on_epoch_end( self, batch, logs = {} ):
#     print( logs )
    self.epochCount += 1
    plot_results( self.models, self.data, batch_size = self.batchSize, epochCount = self.epochCount )

cbPlotResults = PlotResults( models, data, batchSize, "." )

trainLog = vae.fit( xTrain,
           epochs = epochs,
           batch_size = batchSize,
           validation_data = (xTest,None),
           callbacks = [cbPlotResults] )

При этом модель, по-видимому, обучается (см. Графики в связанной записной книжке: https://gist.github.com/v-i-s-h/fdabcb3d85b89ade95758bd014b307f1), и все работает так, как ожидалось.

Теперь я хочу отслеживать функцию потери отдельных реконструкцийа также потери в kl-div во время обучения. Для этого код был изменен на

vae2 = Model( inputs, outputs, name = "vae2" )

# ======================  CHANGE   ==================================
def fn_reconLoss( x, x_hat ):
  # reconLoss = mse( K.flatten(inputs), K.flatten(outputs) )
  reconLoss = binary_crossentropy( K.flatten(x), K.flatten(x_hat) )
  reconLoss *= imageSize*imageSize # Because binary_crossentropy divides by N
  return reconLoss

def fn_klLoss( x, x_hat ):
  klLoss = 1 + zLogVar - K.square(zMean) - K.exp(zLogVar)
  klLoss = K.sum( klLoss, axis = -1 )
  klLoss *= -0.5
  return klLoss

def fn_vaeloss( x, x_hat ):
  return K.mean(fn_reconLoss(x,x_hat) + fn_klLoss(x,x_hat))
# ====================================================================
# vae2.add_loss( fn_vaeloss )
vae2.compile( optimizer = 'adam', loss=fn_vaeloss, metrics = [fn_reconLoss,fn_klLoss] )
vae2.summary()
plot_model( vae2, to_file = 'vae2_model.png', show_shapes = True )


# Callback
import tensorflow.keras.callbacks as cb
class PlotResults( cb.Callback ):
  def __init__( self, models, data, batch_size, model_name ):
    self.models = models
    self.data = data
    self.batchSize = batch_size
    self.model_name = model_name
    self.epochCount = 0
    super().__init__()
  def on_train_begin( self, log = {} ):
    self.epochCount = 0
    plot_results( models, data, batch_size = self.batchSize, epochCount = self.epochCount )

  def on_epoch_end( self, batch, logs = {} ):
#     print( logs )
    self.epochCount += 1
    plot_results( self.models, self.data, batch_size = self.batchSize, epochCount = self.epochCount )

cbPlotResults = PlotResults( models, data, batchSize, "." )

trainLog = vae2.fit( xTrain,
           epochs = epochs,
           batch_size = batchSize,
           validation_data = (xTest,xTest),
           callbacks = [cbPlotResults] )

При этом модель учится неправильно. Даже если потеря кажется уменьшенной, реконструкции совершенно бесполезны(См. https://gist.github.com/v-i-s-h/fdabcb3d85b89ade95758bd014b307f1 последний блок.)

Я не могу понять, где ошибка при определении функций потерь во втором блоке. Это правильный способ сделать это?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...