Вывод кастомной потери в керасе - PullRequest
0 голосов
/ 25 января 2019

Я знаю, что в Keras есть много вопросов, касающихся пользовательских функций потери, но я не смог ответить даже после 3 часов поиска в Google.

Вот очень упрощенный пример моей проблемы.Я понимаю, что этот пример бессмыслен, но я приведу его для простоты, мне, очевидно, нужно реализовать что-то более сложное.

from keras.backend import binary_crossentropy
from keras.backend import mean
def custom_loss(y_true, y_pred):

    zeros = tf.zeros_like(y_true)
    index_of_zeros = tf.where(tf.equal(zeros, y_true))
    ones = tf.ones_like(y_true)
    index_of_ones = tf.where(tf.equal(ones, y_true))

    zero = tf.gather(y_pred, index_of_zeros)
    one = tf.gather(y_pred, index_of_ones)

    loss_0 = binary_crossentropy(tf.zeros_like(zero), zero)
    loss_1 = binary_crossentropy(tf.ones_like(one), one)

    return mean(tf.concat([loss_0, loss_1], axis=0))

Я не понимаю, почему тренировка сети с вышеуказанной функцией потерь в наборе данных двух классов не дает того же результата, что и тренировка со встроенной функцией потери binary-crossentropy.Спасибо!

РЕДАКТИРОВАТЬ : я отредактировал фрагмент кода, добавив среднее значение согласно комментариям ниже.Я все еще получаю такое же поведение как бы то ни было.

1 Ответ

0 голосов
/ 28 января 2019

Я наконец понял это.Функция tf.where ведет себя очень по-разному, когда фигура «неизвестна».Чтобы исправить приведенный выше фрагмент, просто вставьте следующие строки сразу после объявления функции:

y_pred = tf.reshape(y_pred, [-1])
y_true = tf.reshape(y_true, [-1])
...