Почему моя модель работает с `tf.GradientTape ()`, но не работает при использовании `keras.models.Model.fit ()` - PullRequest
1 голос
/ 05 марта 2020

После долгих усилий мне удалось построить tensorflow 2 реализацию существующего проекта pytorch передачи стилей. Затем я хотел получить все приятные дополнительные функции, которые доступны через стандартное обучение Keras, например model.fit().

Но та же модель терпит неудачу при обучении через model.fit(). Модель, кажется, изучает функции контента, но не может изучить функции стиля. Это диаграмма модели в вопросе:

enter image description here

def vgg_layers19(content_layers, style_layers, input_shape=(256,256,3)):
  """ creates a VGG model that returns output values for the given layers
  see: https://keras.io/applications/#extract-features-from-an-arbitrary-intermediate-layer-with-vgg19
  Returns: 
    function(x, preprocess=True):
      Args: 
        x: image tuple/ndarray h,w,c(RGB), domain=(0.,255.)
      Returns:
        a tuple of lists, ([content_features], [style_features])
  usage:
    (content_features, style_features) = vgg_layers16(content_layers, style_layers)(x_train)
  """
  preprocessingFn = tf.keras.applications.vgg19.preprocess_input
  base_model = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
  base_model.trainable = False
  content_features = [base_model.get_layer(name).output for name in content_layers]
  style_features = [base_model.get_layer(name).output for name in style_layers]
  output_features = content_features + style_features

  model = Model( inputs=base_model.input, outputs=output_features, name="vgg_layers")
  model.trainable = False

  def _get_features(x, preprocess=True):
    """
    Args:
      x: expecting tensor, domain=255. hwcRGB
    """
    if preprocess and callable(preprocessingFn): 
      x = preprocessingFn(x)
    output = model(x) # call as tf.keras.Layer()
    return ( output[:len(content_layers)], output[len(content_layers):] )

  return _get_features 



class VGG_Features():
""" get content and style features from VGG model """
  def __init__(self, loss_model, style_image=None, target_style_gram=None):
    self.loss_model = loss_model
    if style_image is not None:
      assert style_image.shape == (256,256,3), "ERROR: loss_model expecting input_shape=(256,256,3), got {}".format(style_image.shape)
      self.style_image = style_image
      self.target_style_gram = VGG_Features.get_style_gram(self.loss_model, self.style_image)
    if target_style_gram is not None:
      self.target_style_gram = target_style_gram

  @staticmethod
  def get_style_gram(vgg_features_model, style_image):
    style_batch = tf.repeat( style_image[tf.newaxis,...], repeats=_batch_size, axis=0)
    # show([style_image], w=128, domain=(0.,255.) )

    # B, H, W, C = style_batch.shape
    (_, style_features) = vgg_features_model( style_batch , preprocess=True ) # hwcRGB
    target_style_gram = [ fnstf_utils.gram(value)  for value in style_features ]  # list
    return target_style_gram  

  def __call__(self, input_batch):
    content_features, style_features = self.loss_model( input_batch, preprocess=True )
    style_gram = tuple(fnstf_utils.gram(value)  for value in style_features)  # tuple(<generator>)
    return (content_features[0],) + style_gram  # tuple = tuple + tuple




class TransformerNetwork_VGG(tf.keras.Model):
  def __init__(self, transformer=transformer, vgg_features=vgg_features):
    super(TransformerNetwork_VGG, self).__init__()
    self.transformer = transformer 
    # type: tf.keras.models.Model
    # input_shapes:  (None, 256,256,3)
    # output_shapes: (None, 256,256,3)


    style_model = {
       'content_layers':['block5_conv2'],
       'style_layers': ['block1_conv1',
                  'block2_conv1',
                  'block3_conv1', 
                  'block4_conv1', 
                  'block5_conv1']
    }
    vgg_model = vgg_layers19( style_model['content_layers'], style_model['style_layers'] )

    self.vgg_features = VGG_Features(vgg_model, style_image=style_image, batch_size=batch_size) 

    # input_shapes:  (None, 256,256,3)
    # output_shapes: [(None, 16, 16, 512),  (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
    #                [ content_loss,        style_loss_1, style_loss_2, style_loss_3, style_loss_4, style_loss_5 ]


  def call(self, inputs):
    x = inputs                # shape=(None, 256,256,3)

    # shape=(None, 256,256,3)
    generated_image = self.transformer(x)                    

    # shape=[(None, 16, 16, 512),  (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
    vgg_feature_losses = self.vgg(generated_image)           

    return vgg_feature_losses       # tuple(content1, style1, style2, style3, style4, style5)

Стиль изображения style image

FEATURE_WEIGHTS = [1.0, 1,0, 1,0, 1,0, 1,0, 1,0]

GradientTape обучение

С tf.GradientTape() l oop я вручную обрабатываю несколько выходов, например, набор из 6 тензоров от TransformerNetwork_VGG(x_train). Этот метод обучается правильно.

  @tf.function()
  def train_step(x_train, y_true, loss_weights=None, log_freq=10):
    with tf.GradientTape() as tape:
      y_pred = TransformerNetwork_VGG(x_train)
      generated_content_features = y_pred[:1]
      generated_style_gram = y_pred[1:]


      y_true = TransformerNetwork_VGG.vgg(x_train)
      target_content_features = y_true[:1]
      target_style_gram = TransformerNetwork_VGG.vgg.target_style_gram

      content_loss = get_MEAN_mse_loss(target_content_features, generated_content_features, weights)
      style_loss = tuple(get_MEAN_mse_loss(x,y)*w for x,y,w in zip(target_style_gram, generated_style_gram, weights))

      total_loss = content_loss + = tf.reduce_sum(style_loss)
      TransformerNetwork = TransformerNetwork_VGG.transformer
      grads = tape.gradient(total_loss, TransformerNetwork.trainable_weights)
      optimizer.apply_gradients(zip(grads, TransformerNetwork.trainable_weights))
# GradientTape epoch=5: 
# losses:             [   6078.71         70.23  4495.13 13817.65 88217.99    48.36]

gradient tape

model.fit() learning

С tf.keras.models.Model.fit(), несколько выходов, например, набор из 6 тензоров , подаются на функцию потерь индивидуально как loss(y_pred, y_true), а затем умножаются на правильный вес на reduction. Этот метод учится приближаться к content_image, но не учится , чтобы минимизировать потери стиля! Я не могу понять, почему.

  history = TransformerNetwork_VGG.fit(
    x=train_dataset.repeat(NUM_EPOCHS),
    epochs=NUM_EPOCHS,
    steps_per_epoch=NUM_BATCHES,
    callbacks=callbacks,
  )
# model.fit() epoch=5: 
# losses:             [  4661.08       219.95   6959.01   4897.39 209201.16     84.68]]

model-fit

50 эпох, с увеличенными style_weights, FEATURE_WEIGHTS = [0.1854, 1605.23, 25.08, 8.16, 1.28, 2330.79] # потеря стиля буста x100

model-fit after 50

step = 50, потери = [269899.45 337.5 69617.7 38424.96 9192.36 85903.44 66423.51]

проверка mse потери * веса

Я проверил свою модель с потерями и весами, зафиксированными следующим образом * FEATURE_WEIGHTS = SEQ = [1., 2., 3., 4., 5., 6., ] * MSELoss (y_true, y_pred) == tf.ones () равной формы и подтвердил, что model.fit() правильно обрабатывает несколько выходных потерь * весов

losses as ones

У меня есть проверил все, что могу придумать, но я не могу понять, как заставить модель учиться правильно с model.fit(). Чего мне не хватает ??

Полный блокнот доступен здесь: https://colab.research.google.com/github/mixuala/fast_neural_style_pytorch/blob/master/notebook/%5BSO%5D_FastStyleTransfer.ipynb

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...