Настраиваемая двоичная потеря кроссцентропии в кератах, которая игнорирует столбцы без ненулевых значений - PullRequest
1 голос
/ 21 сентября 2019

Я пытаюсь сегментировать данные, где метка может быть довольно разреженной.Поэтому я хочу вычислять только градиенты в столбцах, которые имеют хотя бы одно ненулевое значение.

Я пробовал некоторые методы, где я применяю дополнительный ввод, который является маской этих ненулевых столбцов, но учитывая, что все необходимыеинформация уже содержится в y_true, метод, который ищет только маску y_true, чтобы определенно найти маску, определенно предпочтительнее.

Если бы я реализовал ее с помощью numpy, она бы выглядела примерно так: *В этом примере 1007 *

def loss(y_true, y_pred):
    indices = np.where(np.sum(y_true, axis=1) > 0)
    return binary_crossentropy(y_true[indices], y_pred[indices])

y_true и y_pred являются векторизованными 2D-изображениями.

Как это можно «перевести» в дифференцируемую функцию потерь Кераса?

1 Ответ

0 голосов
/ 21 сентября 2019

Использование tf -совместимых операций через tf и keras.backend:

import tensorflow as tf
import keras.backend as K
from keras.losses import binary_crossentropy

def custom_loss(y_true, y_pred):
    indices = K.squeeze(tf.where(K.sum(y_true, axis=1) > 0))
    y_true_sparse = K.cast(K.gather(y_true, indices), dtype='float32')
    y_pred_sparse = K.cast(K.gather(y_pred, indices), dtype='float32')
    return binary_crossentropy(y_true_sparse, y_pred_sparse) # returns a tensor

Я не уверен в точных характеристиках размерности вашей проблемы, но потери должны быть оценены до одного значения- что выше не делает, так как вы передаете многомерные прогнозы и метки.Чтобы уменьшить яркость, оберните вышеприведенный код, например, K.mean.Пример:

y_true = np.random.randint(0,2,(10,2))
y_pred = np.abs(np.random.randn(10,2))
y_pred /= np.max(y_pred) # scale between 0 and 1

print(K.get_value(custom_loss(y_true, y_pred))) # get_value evaluates returned tensor
print(K.get_value(K.mean(custom_loss(y_true, y_pred))
>> [1.1489482  1.2705883  0.76229745  5.101402  3.1309896] # sparse; 5 / 10 results
>> 2.28284 # single value, as required

(И наконец, обратите внимание, что эта редкость смещает потери, исключая все нулевые столбцы из общего числа меток / пред.или K.size)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...