Как написать пользовательскую функцию потерь keras, когда вам нужно ввести значение для расчета потерь? - PullRequest
1 голос
/ 20 февраля 2020

perceptual losses

Я пытаюсь продублировать переводную бумагу быстрого стиля (см. Схему выше), используя метод, описанный в Встроенные циклы обучения и оценки keras

У меня проблемы с пониманием, как это сделать с помощью специального класса потерь (см. Ниже).

Для расчета компонентов потерь мне нужно следующее:

  • y_hat, сгенерированное изображение для получения
(generated_content_features, generated_style_features) = VGG(y_hat)
generated_style_gram = [ utils.gram(value) for value in generated_style_features ]
  • target_style_gram, что означает c, поэтому я могу получить один раз из target_style_features и кеша, (_,target_style_features) = VGG(y_s)
  • x, InputImage (такой же, как y_c ContentTarget), чтобы получить (target_content_features, _) = VGG(x)

Я обнаружил, что я исправляю множество вещей в классе потерь, tf.keras.losses.Loss, чтобы получить эти значения и в конечном итоге выполнить расчет потерь. Это особенно верно для target_content_features, для которого требуется входное изображение , что я передал через y_true, но это, очевидно, хак

y_pred = generated_image # y_hat from diagram, shape=(b,256,256,3)
y_true = x # hack: access the input image here

lossFn = PerceptualLosses_Loss(VGG, target_style_gram)
loss = lossFn(y_true, y_pred)


class PerceptualLosses_Loss(tf.losses.Loss):
  name="PerceptualLosses_Loss"
  reduction=tf.keras.losses.Reduction.AUTO
  RGB_MEAN_NORMAL_VGG = tf.constant( [0.48501961, 0.45795686, 0.40760392], dtype=tf.float32)

  def __init__(self, loss_network, target_style_gram, loss_weights=None):
    super(PerceptualLosses_Loss, self).__init__( name=self.name, reduction=self.reduction )
    self.target_style_gram = target_style_gram # repeated in y_true
    print("PerceptualLosses_Loss init()", type(target_style_gram), type(self.target_style_gram))
    self.VGG = loss_network

  def call(self, y_true, y_pred):

    b,h,w,c = y_pred.shape
    #???: y_pred.shape=(None, 256,256,3), need batch dim for utils.gram(value)
    generated_batch = tf.reshape(y_pred, (BATCH_SIZE,h,w,c) )

    # generated_batch: expecting domain=(+-int), mean centered
    generated_batch = tf.nn.tanh(generated_batch) # domain=(-1.,1.), mean centered

    # reverse VGG mean_center
    generated_batch = tf.add( generated_batch, self.RGB_MEAN_NORMAL_VGG) # domain=(0.,1.)
    generated_batch_BGR_centered = tf.keras.applications.vgg19.preprocess_input(generated_batch*255.)/255.
    generated_content_features, generated_style_features = self.VGG( generated_batch_BGR_centered, preprocess=False )
    generated_style_gram = [ utils.gram(value)  for value in generated_style_features ]  # list

    y_pred = generated_content_features + generated_style_gram
    # print("PerceptualLosses_Loss: y_pred, output_shapes=", type(y_pred), [v.shape for v in y_pred])
    # PerceptualLosses_Loss: y_pred, output_shapes= [
    #   TensorShape([4, 16, 16, 512]), 
    #   TensorShape([4, 64, 64]), 
    #   TensorShape([4, 128, 128]), 
    #   TensorShape([4, 256, 256]), 
    #   TensorShape([4, 512, 512]), 
    #   TensorShape([4, 512, 512])
    # ]

    if tf.is_tensor(y_true):
      # print("detect y_true is image", type(y_true), y_true.shape)
      x_train = y_true
      x_train_BGR_centered = tf.keras.applications.vgg19.preprocess_input(x_train*255.)/255.
      target_content_features, _ = self.VGG(x_train_BGR_centered, preprocess=False )
      # ???: target_content_features[0].shape=(None, None, None, 512), should be shape=(4, 16, 16, 512)
      target_content_features = [tf.reshape(v, generated_content_features[i].shape) for i,v in enumerate(target_content_features)]
    elif isinstance(y_true, tuple):
      print("detect y_true is tuple(target_content_features + self.target_style_gram)", y_true[0].shape)
      target_content_features = y_true[:len(generated_content_features)]
      if self.target_style_gram is None:
        self.target_style_gram = y_true[len(generated_content_features):]
    else:
      assert False, "unexpected result for y_true"

    # losses = tf.keras.losses.MSE(y_true, y_pred)
    def batch_reduce_sum(y_true, y_pred, weight, name):
      losses = tf.zeros(BATCH_SIZE)
      for a,b in zip(y_true, y_pred):
        # batch_reduce_sum()
        loss = tf.keras.losses.MSE(a,b)
        loss = tf.reduce_sum(loss, axis=[i for i in range(1,len(loss.shape))] )
        losses = tf.add(losses, loss)
      return tf.multiply(losses, weight, name="{}_loss".format(name)) # shape=(BATCH_SIZE,)

    c_loss = batch_reduce_sum(target_content_features, generated_content_features, CONTENT_WEIGHT, 'content_loss')
    s_loss = batch_reduce_sum(self.target_style_gram, generated_style_gram, STYLE_WEIGHT, 'style_loss')
    return (c_loss, s_loss)

I также пытался предварительно вычислить y_true в tf.data.Dataset, но, хотя он работал нормально под eager execution, он вызывал ошибку во время model.fit()

xy_true_Dataset = tf.data.Dataset.from_generator(
    xyGenerator_y_true(image_ds, VGG, target_style_gram),
    output_types=(tf.float32, (tf.float32,  tf.float32,tf.float32,tf.float32,tf.float32,tf.float32) ),
    output_shapes=(
      (256,256,3),
      ( (16, 16, 512), (64, 64), (128, 128), (256, 256), (512, 512), (512, 512)) 
    ),
  )

# eager execution, y_true: <class 'tuple'> [TensorShape([4, 16, 16, 512]), TensorShape([4, 64, 64]), TensorShape([4, 128, 128]), TensorShape([4, 256, 256]), TensorShape([4, 512, 512]), TensorShape([4, 512, 512])]
# model.fit(), y_true: <class 'tensorflow.python.framework.ops.Tensor'> (None, None, None, None)

ValueError: Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), for inputs ['output_1'] but instead got the following list of 6 arrays: [<tf.Tensor 'args_1:0' shape=(None, 16, 16, 512) dtype=float32>, <tf.Tensor 'args_2:0' shape=(None, 64, 64) dtype=float32>, <tf.Tensor 'args_3:0' shape=(None, 128, 128) dtype=float32>, <tf.Tensor 'arg...

Есть ли у меня неправильный подход к эта проблема?

...