Можно ли написать пользовательскую функцию потерь, основанную на разнице выходных данных в пакете в Keras? - PullRequest
0 голосов
/ 12 мая 2019

Я пытаюсь реализовать функцию потерь в Керасе, которая может делать следующее:

Предположим, что y0, y1, ..., yn - модель пакетный вывод для пакетного ввода x0, x1, ..., xn, здесь batch_size - это n + 1, выход yi для каждого xi - скалярное значение, которое я хочу, чтобы функция потерь вычисляла полную потерю для этой партии следующим образом:

K.log (K.sigmoid (у1-у0)) + K.log (K.sigmoid (у2-у1)) + ... K.log (K.sigmoid (уп-ин-1))

Я думал использовать лямбда-слой, чтобы сначала преобразовать пакетный вывод [y0, y1, ..., yn] в [y1-y0, y2-y1, ..., yn-yn-1], а затем использовать настраиваемая функция потерь на преобразованном выходе.

Однако я не уверен, может ли Keras понять, что в лямбда-слое нет веса для обновления, и мне неясно , как Keras будет распространять градиент обратно через слой Lambda , как Keras обычно требует, чтобы каждый слой / функция потерь работала с одним входом сэмпла, но мой слой будет принимать весь выход партии сэмплов. Кто-нибудь решал подобные проблемы раньше? Спасибо!

1 Ответ

0 голосов
/ 12 мая 2019

Работает ли нарезка, как показано ниже, (хотя я не использую керас).

batch = 4
num_classes = 6
logits = tf.random.uniform(shape=[batch, num_classes])

logits1 = tf.slice(logits, (0, 0), [batch, num_classes-1])
logits2 = tf.slice(logits, (0, 1), [batch, num_classes-1])

delta = logits2 - logits1
loss = tf.reduce_sum(tf.log(tf.nn.sigmoid(delta)), axis=-1)

with tf.Session() as sess:
  logits, logits1, logits2, delta, loss  = sess.run([logits, logits1, logits2, 
                                                     delta, loss])

  print 'logits\n', logits
  print 'logits2\n', logits2
  print 'logits1\n', logits1
  print 'delta\n', delta
  print 'loss\n', loss

Результат:

logits
[[ 0.61241663  0.70075285  0.98333454  0.4117974   0.5943476   0.84245574]
 [ 0.02499413  0.22279179  0.70742595  0.34853518  0.7837007   0.88074362]
 [ 0.35030317  0.36670768  0.64244425  0.87957716  0.22823489  0.45076978]
 [ 0.38116801  0.39040041  0.82510674  0.64789391  0.45415008  0.03520513]]
logits2
[[ 0.70075285  0.98333454  0.4117974   0.5943476   0.84245574]
 [ 0.22279179  0.70742595  0.34853518  0.7837007   0.88074362]
 [ 0.36670768  0.64244425  0.87957716  0.22823489  0.45076978]
 [ 0.39040041  0.82510674  0.64789391  0.45415008  0.03520513]]
logits1
[[ 0.61241663  0.70075285  0.98333454  0.4117974   0.5943476 ]
 [ 0.02499413  0.22279179  0.70742595  0.34853518  0.7837007 ]
 [ 0.35030317  0.36670768  0.64244425  0.87957716  0.22823489]
 [ 0.38116801  0.39040041  0.82510674  0.64789391  0.45415008]]
delta
[[ 0.08833623  0.28258169 -0.57153714  0.18255019  0.24810815]
 [ 0.19779766  0.48463416 -0.35889077  0.43516552  0.09704292]
 [ 0.01640451  0.27573657  0.23713291 -0.65134227  0.22253489]
 [ 0.0092324   0.43470633 -0.17721283 -0.19374382 -0.41894495]]
loss
[-3.41376281 -3.11249781 -3.49031925 -3.69255161]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...