Управление потерей керас с внешними данными - PullRequest
0 голосов
/ 24 апреля 2020

Предположим, что у нас есть модель model_A, и мы хотим построить обратное распространение на основе 3 различных функций потерь. Первая потеря (Loss_1) должна основываться на выводе model_A, Loss_2 и Loss_3 могут быть получены из чего-то еще. Думайте об этом как об отклонении от неизвестного источника, как в автоматизации процессов, если вы хотите создать свой ПИД-контроллер. Самый простой способ - это мой подход, но на самом деле он терпит неудачу, потому что граф построен не так, как я хочу, потому что X_realB и X_real C не имеют связи с model_A и игнорируются керасами с бэкэндом tenorflow.

Основной вопрос: Как я могу передать новые переменные в моей функции потерь, не обрабатывая их в модели и не влияя на проблему минимизации?

def generator_model(model_A):

  model_A.trainable = True

# import
  X_realA = Input(shape=image_shape)
  X_realB = Input(shape=image_shape)
  X_realC = Input(shape=image_shape)

# generate Fake image
  Fake_A=model_A(X_realA)


  model = Model([X_realA],[Fake_A,X_realB ,X_realC])

  opt = Adam(lr=0.0002, beta_1=0.5)
  model.compile(loss=["mse","mse","mse"],loss_weights=[1,1,1], optimizer=opt)
  model.summary()
  return model

I пробовал что-то еще в последние 2 дня. Обтекание [FakeA,B,C] в пользовательском лямбда-слое для расчета комбинированных потерь (одно выходное значение этого пользовательского слоя). Чем передать эту потерю, в фиктивной пользовательской функции потерь, которая просто выводит объединенное значение лямбда-слоя. Вот пример:

    # import A,B,C and than pass A into Generator .... and after that:

    combined_loss= Lambda(lambda x: combined_loss_func(x))([FakeA,B,C])

    model=Model([A,B,C],[combined_loss],loss=dummy_loss)

    def dummy_loss(y_pred,y_true):
      return y_pred


`combined_loss` could look like that:

    def combined_loss_func(x):

      FakeA,B,C=x[0],x[1],x[2]

      # transform all inputs into one row-tensors
      shape=tf.shape(FakeA)
      FakeA=tf.reshape(FakeA,[1,shape[0]*shape[1]*shape[2]*shape[3]])   
      shape=tf.shape(B)
      B=tf.reshape(B,[1,shape[0]*shape[1]*shape[2]*shape[3]]) 
      shape=tf.shape(C)
      C=tf.reshape(C,[1,shape[0]*shape[1]*shape[2]*shape[3]]) 

      # build up a hypothetical ground truth
      FakeA_ones=tf.ones_like(FakeA)
      A_ones=tf.ones_like(A)
      B_ones=tf.ones_like(B)

      # calculate losses
      loss0=keras.losses.mse(FakeA,FakeA_ones)
      loss1=keras.losses.mse(A,A_ones)
      loss2=keras.losses.mse(B,B_ones)

      # sum them up
      summe=tf.math.add(loss0,loss1)
      summe=tf.math.add(summe,loss2)

      # average them
      avg=tf.math.truediv(summe,3.0)
      avg=tf.expand_dims(summe,axis=-1)

      return avg

Если я сейчас попытаюсь установить нулевую потерю FakeA, обратного распространения на modelA больше не произойдет, или, по крайней мере, ничего в системе больше не изменится:

       # calculate losses
      loss0=keras.losses.mse(FakeA,FakeA_ones) * 0
      loss1=keras.losses.mse(A,A_ones)
      loss2=keras.losses.mse(B,B_ones)

Сначала это кажется действительно хорошим, но когда я go перешел в пользовательскую функцию, а не использовал FakeA, который является единственным и единственным тензором, прошедшим через генератор. После того, как я получу значение для моей функции потерь, которая кажется точной, но на самом деле ничего не происходит, мой цикл Ган вообще не улучшается, и все изображения прошли, хотя все еще выглядят одинаково, даже после 100 эпох.

Не означает ли это, что другие потери даже не учитываются, и я просто использую убыток1 по сравнению с FakeA, и я просто получаю разные результаты из-за типичного расхождения системы GAN динамически c?

PS: Я знаю, что за алгеброй и математикой стоит идея backprop и все, и, следовательно, умножение на 0 делает градиенты нулевыми. Но разве нет способа манипулировать окончательным значением потерь, которое, как правило, сводится к минимуму, но все же удерживает градиенты до modelA, как раньше?

Спасибо за ваше время!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...