Функция потери Кераса с дополнительным динамическим параметром - PullRequest
0 голосов
/ 02 мая 2018

Я работаю над реализацией приоритетного воспроизведения опыта для сети с глубоким q, и часть спецификации заключается в умножении градиентов на так называемые весовые коэффициенты выборки (IS). Модификация градиента обсуждается в разделе 3.4 следующего документа: https://arxiv.org/pdf/1511.05952.pdf Я пытаюсь создать пользовательскую функцию потерь, которая принимает массив весов IS в дополнение к y_true и y_pred. * 1005. *

Вот упрощенная версия моей модели:

import numpy as np
import tensorflow as tf

# Input is RAM, each byte in the range of [0, 255].
in_obs = tf.keras.layers.Input(shape=(4,))

# Normalize the observation to the range of [0, 1].
norm = tf.keras.layers.Lambda(lambda x: x / 255.0)(in_obs)

# Hidden layers.
dense1 = tf.keras.layers.Dense(128, activation="relu")(norm)
dense2 = tf.keras.layers.Dense(128, activation="relu")(dense1)
dense3 = tf.keras.layers.Dense(128, activation="relu")(dense2)
dense4 = tf.keras.layers.Dense(128, activation="relu")(dense3)

# Output prediction, which is an action to take.
out_pred = tf.keras.layers.Dense(2, activation="linear")(dense4)

opt     = tf.keras.optimizers.Adam(lr=5e-5)
network = tf.keras.models.Model(inputs=in_obs, outputs=out_pred)
network.compile(optimizer=opt, loss=huber_loss_mean_weighted)

Вот моя пользовательская функция потерь, которая является просто реализацией Huber Loss, умноженной на весовые коэффициенты IS:

'''
 ' Huber loss: https://en.wikipedia.org/wiki/Huber_loss
'''
def huber_loss(y_true, y_pred):
  error = y_true - y_pred
  cond  = tf.keras.backend.abs(error) < 1.0

  squared_loss = 0.5 * tf.keras.backend.square(error)
  linear_loss  = tf.keras.backend.abs(error) - 0.5

  return tf.where(cond, squared_loss, linear_loss)

'''
 ' Importance Sampling weighted huber loss.
'''
def huber_loss_mean_weighted(y_true, y_pred, is_weights):
  error = huber_loss(y_true, y_pred)

  return tf.keras.backend.mean(error * is_weights)

Важным битом является то, что is_weights является динамическим, то есть он меняется каждый раз, когда вызывается fit(). Поэтому я не могу просто закрыть is_weights, как описано здесь: Сделать пользовательскую функцию потерь в кератах

Я нашел этот код в Интернете, который, кажется, использует слой Lambda для расчета потерь: https://github.com/keras-team/keras/blob/master/examples/image_ocr.py#L475 Это выглядит многообещающе, но я изо всех сил пытаюсь понять его / адаптировать к моей конкретной проблеме. Любая помощь приветствуется.

1 Ответ

0 голосов
/ 02 мая 2018

OK. Вот пример.

from keras.layers import Input, Dense, Conv2D, MaxPool2D, Flatten
from keras.models import Model
from keras.losses import categorical_crossentropy

def sample_loss( y_true, y_pred, is_weight ) :
    return is_weight * categorical_crossentropy( y_true, y_pred ) 

x = Input(shape=(32,32,3), name='image_in')
y_true = Input( shape=(10,), name='y_true' )
is_weight = Input(shape=(1,), name='is_weight')
f = Conv2D(16,(3,3),padding='same')(x)
f = MaxPool2D((2,2),padding='same')(f)
f = Conv2D(32,(3,3),padding='same')(f)
f = MaxPool2D((2,2),padding='same')(f)
f = Conv2D(64,(3,3),padding='same')(f)
f = MaxPool2D((2,2),padding='same')(f)
f = Flatten()(f)
y_pred = Dense(10, activation='softmax', name='y_pred' )(f)
model = Model( inputs=[x, y_true, is_weight], outputs=y_pred, name='train_only' )
model.add_loss( sample_loss( y_true, y_pred, is_weight ) )
model.compile( loss=None, optimizer='sgd' )
print model.summary()

Примечание: поскольку вы добавляете убыток через add_loss(), вам не нужно делать это через compile( loss=xxx ).

Что касается обучения модели, нет ничего особенного, кроме того, что вы переместили y_true в конец ввода. Смотри ниже

import numpy as np 
a = np.random.randn(8,32,32,3)
a_true = np.random.randn(8,10)
a_is_weight = np.random.randint(0,2,size=(8,1))
model.fit( [a, a_true, a_is_weight] )

Наконец, вы можете создать тестовую модель (которая разделяет все веса в model) для более легкого использования, т.е.

test_model = Model( inputs=x, outputs=y_pred, name='test_only' )
a_pred = test_model.predict( a )
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...