Как сделать пользовательскую функцию потерь, которая использует предыдущий вывод из сети в Керасе? - PullRequest
0 голосов
/ 24 июня 2019

Я пытаюсь создать пользовательскую функцию потерь, которая берет предыдущий вывод (вывод предыдущей итерации) из сети и использует его с текущим выводом.

Вот что я пытаюсь сделать, но я не знаю, как это сделать

def l_loss(prev_output):

    def loss(y_true, y_pred):

        pix_loss = K.mean(K.square(y_pred - y_true), axis=-1)

        pase = K.variable(100)

        diff = K.mean(K.abs(prev_output - y_pred))
        movement_loss = K.abs(pase - diff)
        total_loss = pix_loss + movement_loss

        return total_loss
    return loss

self.model.compile(optimizer=Adam(0.001, beta_1=0.5, beta_2=0.9),
 loss=l_loss(?))

Надеюсь, вы мне поможете.

1 Ответ

1 голос
/ 25 июня 2019

Вот что я пробовал:

from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import Sequential
from tensorflow.keras import backend as K

class MovementLoss(object):
  def __init__(self):
    self.var = None

  def __call__(self, y_true, y_pred, sample_weight=None):
    mse = K.mean(K.square(y_true - y_pred), axis=-1)
    if self.var is None:
      z = np.zeros((32,))
      self.var = K.variable(z)
    delta = K.update(self.var, mse - self.var)
    return mse + delta


def make_model():
  model = Sequential()
  model.add(Dense(1, input_shape=(4,)))
  loss = MovementLoss()
  model.compile('adam', loss)
  return model

model = make_model()
model.summary()


Использование примера тестовых данных.

import numpy as np

X = np.random.rand(32, 4)

POLY = [1.0, 2.0, 0.5, 3.0]
def test_fn(xi):
  return np.dot(xi, POLY)

Y = np.apply_along_axis(test_fn, 1, X)

history = model.fit(X, Y, epochs=4)

Я вижу, что функция потерь колеблется, как мне кажется, под влиянием последней дельты пакета. Обратите внимание, что данные функции потери не соответствуют вашей заявке.

Важнейшим шагом является то, что шаг K.update должен быть частью графика (насколько я понимаю).

Что достигается:

delta = K.update(var, delta)
return x + delta
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...