How to implement a custom loss function with multiple inputs and outputs

I am trying to implement the loss function from this paper. [image: the paper's loss function]

Briefly, I is the input, P is the ground truth, lambda is a constant, f(I) is the output of a segmentation network, g(I) is the output of another auxiliary network, and s(f(I)) is a Sobel filter applied to f(I).
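
From that description, my reading of the screenshot (the paper's exact notation may differ) is that the loss has the form

    L(I, P) = L_ce(f(I), P) + lambda * |s(f(I)) - g(I)|

where L_ce is a cross-entropy computed only over the annotated pixels of P.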

Since the second term needs the outputs of both f and g, I simply add a Sobel filter layer after the output layer of f to get s(f(I)), then use a subtract layer to compute s(f(I)) - g(I) (the absolute value is taken inside the loss). Therefore, my model actually has two outputs: one feeds the first term of the loss function, the other feeds the second.
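
Here is a minimal sketch of that wiring, assuming single-channel sigmoid outputs and a 256x256 RGB input; the two Conv2D layers are just stand-ins for my real f and g networks, the layer names are placeholders, and I use tf.image.sobel_edges for the Sobel filter:

import tensorflow as tf
from tensorflow.keras import layers, Model

def sobel_magnitude(x):
    #tf.image.sobel_edges returns (batch, h, w, channels, 2), with the
    #vertical/horizontal gradient components stacked in the last axis
    edges = tf.image.sobel_edges(x)
    return tf.sqrt(tf.reduce_sum(tf.square(edges), axis=-1) + 1e-7)

inputs = layers.Input(shape=(256, 256, 3))

#stand-ins for the real branches: f(I) is the segmentation network,
#g(I) is the auxiliary network
f_out = layers.Conv2D(1, 3, padding='same', activation='sigmoid', name='f_out')(inputs)
g_out = layers.Conv2D(1, 3, padding='same', activation='sigmoid', name='g_out')(inputs)

s_f = layers.Lambda(sobel_magnitude, name='sobel')(f_out)   #s(f(I))
diff = layers.Subtract(name='edge_diff')([s_f, g_out])      #s(f(I)) - g(I)

model = Model(inputs, [f_out, diff])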

However, my model doesn't work as I expected, so I want to know whether my loss functions contain mistakes, or whether I shouldn't build the model this way (as described in the previous paragraph).

My model looks like this: [image: model architecture diagram]

Here is my code.

loss1() is used for the first term, loss2() for the second.

y_pred in loss1() is f(I); y_pred in loss2() is s(f(I)) - g(I).

from tensorflow.keras import backend as K

def loss1(y_true, y_pred):
    #first term: partial cross-entropy over annotated pixels only
    #clip predictions to avoid log(0); labels stay hard 0/1
    y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())

    #get point and boundary annotations separately
    y_true_pos = y_true[:, :, :, 0]
    y_true_neg = y_true[:, :, :, 1]

    #flatten to (batch, pixels)
    y_true_pos_f = K.batch_flatten(y_true_pos)
    y_true_neg_f = K.batch_flatten(y_true_neg)
    y_pred_f = K.batch_flatten(y_pred)

    #number of annotated pixels per image, shape (batch,);
    #guard against division by zero for images with no annotations
    y_true_pos_count = K.maximum(K.sum(y_true_pos_f, axis=-1), K.epsilon())
    y_true_neg_count = K.maximum(K.sum(y_true_neg_f, axis=-1), K.epsilon())

    #cross-entropy of each image, shape (batch,)
    cross_entropy_pos = K.sum(-y_true_pos_f * K.log(y_pred_f), axis=-1)
    cross_entropy_neg = K.sum(-y_true_neg_f * K.log(1 - y_pred_f), axis=-1)

    loss_pos = cross_entropy_pos / y_true_pos_count
    loss_neg = cross_entropy_neg / y_true_neg_count

    return 1.0 * loss_pos + 0.1 * loss_neg

def loss2(y_true, y_pred):
    #second term: y_pred is s(f(I)) - g(I), so take the mean absolute
    #difference per image; y_true is unused here
    return K.mean(K.batch_flatten(K.abs(y_pred)), axis=-1)
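
With the model built as in the sketch above, each output gets its own loss, and the paper's constant lambda goes into loss_weights; the output names and the lambda value here are placeholders:

LAMBDA = 0.5  #placeholder for the paper's constant

model.compile(
    optimizer='adam',
    loss={'f_out': loss1, 'edge_diff': loss2},
    loss_weights={'f_out': 1.0, 'edge_diff': LAMBDA},
)

Since loss2 ignores y_true, an all-zeros array of the right shape can be passed as the target for the 'edge_diff' output during training.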