I am trying to implement the loss function from this paper.
Briefly, I is the input, P is the ground truth, lambda is a constant, f(I) is the output of a segmentation network, g(I) is the output of another, auxiliary network, and s(f(I)) is a Sobel filter applied to f(I).
Since the second term needs the outputs of both f and g, I simply add a Sobel filter layer after the output layer of f to get s(f(I)), then use a subtract layer to compute |s(f(I)) - g(I)|. Therefore, my model actually has 2 outputs: one feeds the first term of the loss function, the other feeds the second.
However, my model doesn't work as I expected. So I want to know: is there a mistake in my loss functions, or is the way I built the model (described in the previous paragraph) the wrong approach?
My model looks like this: [image: model diagram]
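To make the second term concrete, here is a minimal NumPy sketch of what the Sobel and subtract branch is meant to compute for a single image. It is not the model code. It assumes that s(.) is the Sobel gradient magnitude with edge padding (the paper may define or normalize it differently), and the helper names `filter3x3`, `sobel_magnitude` and `second_term` are made up for this illustration.

```python
import numpy as np

# Standard 3x3 Sobel kernels for horizontal and vertical gradients.
SOBEL_X = np.array([[-1.0, 0.0, 1.0],
                    [-2.0, 0.0, 2.0],
                    [-1.0, 0.0, 1.0]])
SOBEL_Y = SOBEL_X.T


def filter3x3(image, kernel):
    """Cross-correlate a 2D image with a 3x3 kernel, using edge padding."""
    padded = np.pad(image, 1, mode="edge")
    h, w = image.shape
    out = np.zeros((h, w), dtype=float)
    for dy in range(3):
        for dx in range(3):
            out += kernel[dy, dx] * padded[dy:dy + h, dx:dx + w]
    return out


def sobel_magnitude(image):
    """Gradient magnitude of a 2D map (assumed definition of s)."""
    gx = filter3x3(image, SOBEL_X)
    gy = filter3x3(image, SOBEL_Y)
    return np.sqrt(gx ** 2 + gy ** 2)


def second_term(f_out, g_out):
    """Mean over pixels of |s(f(I)) - g(I)| for a single image."""
    return np.mean(np.abs(sobel_magnitude(f_out) - g_out))


# Toy example: a 6x6 segmentation map with a vertical step edge.
f_out = np.zeros((6, 6))
f_out[:, 3:] = 1.0
edges = sobel_magnitude(f_out)
```

In this toy case the edge response is nonzero only in the two columns next to the step, so the second term is zero when g(I) equals `edges` and positive when g(I) is all zeros. Comparing the output of the Sobel layer in the model against a reference like this on one small input is one way to check that the layer does what is intended.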
Here is my code. loss1() is used for the first term and loss2() for the second. y_pred in loss1() is f(I); y_pred in loss2() is |s(f(I)) - g(I)|.
from keras import backend as K  # assumed import; K is the Keras backend


def loss1(y_true, y_pred):
    # first term
    # only consider pixels that are annotated
    y_true = K.clip(y_true, K.epsilon(), 1 - K.epsilon())
    y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())

    # get point and boundary annotation separately
    y_true_pos = y_true[:, :, :, 0]
    y_true_neg = y_true[:, :, :, 1]

    # flatten, (batch, pixels)
    y_true_pos_f = K.batch_flatten(y_true_pos)
    y_true_neg_f = K.batch_flatten(y_true_neg)
    y_pred_f = K.batch_flatten(y_pred)

    # only consider annotated pixels
    # (batch, )
    y_true_pos_count = K.sum(y_true_pos_f, axis=-1)
    y_true_neg_count = K.sum(y_true_neg_f, axis=-1)

    # cross_entropy of each image
    # (batch, )
    cross_entropy_pos = K.sum(-y_true_pos_f * K.log(y_pred_f), axis=-1)
    cross_entropy_neg = K.sum(-y_true_neg_f * K.log(1 - y_pred_f), axis=-1)

    # loss_pos = K.mean(cross_entropy_pos / y_true_pos_count)
    # loss_neg = K.mean(cross_entropy_neg / y_true_neg_count)
    loss_pos = cross_entropy_pos / y_true_pos_count
    loss_neg = cross_entropy_neg / y_true_neg_count

    return 1.0 * loss_pos + 0.1 * loss_neg


def loss2(y_true, y_pred):
    # second term; y_true is ignored
    return K.mean(K.batch_flatten(K.abs(y_pred)), axis=-1)
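For checking loss1() by hand on small arrays, a plain NumPy reference of the first term may help. This is only a sketch under assumptions: channel 0 of y_true is taken to be a binary mask of positive pixels and channel 1 a binary mask of negative pixels, the weights 1.0 and 0.1 are copied from loss1(), and the name `masked_cross_entropy` is hypothetical. Unlike loss1(), it deliberately leaves y_true unclipped and guards against a zero annotation count, so the two versions can differ slightly on unannotated pixels.

```python
import numpy as np


def masked_cross_entropy(y_true, y_pred, w_pos=1.0, w_neg=0.1, eps=1e-7):
    """Per-image cross-entropy over annotated pixels only.

    y_true: (batch, H, W, 2); channel 0 marks positive pixels,
            channel 1 marks negative pixels (assumed binary masks).
    y_pred: (batch, H, W) foreground probabilities, i.e. f(I).
    Returns an array of shape (batch,).
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    batch = y_true.shape[0]
    pos = y_true[..., 0].reshape(batch, -1)
    neg = y_true[..., 1].reshape(batch, -1)
    pred = y_pred.reshape(batch, -1)

    # Sum the cross-entropy over annotated pixels of each kind.
    ce_pos = np.sum(-pos * np.log(pred), axis=-1)
    ce_neg = np.sum(-neg * np.log(1.0 - pred), axis=-1)

    # Avoid dividing by zero when an image has no annotation of one kind.
    n_pos = np.maximum(pos.sum(axis=-1), 1.0)
    n_neg = np.maximum(neg.sum(axis=-1), 1.0)
    return w_pos * ce_pos / n_pos + w_neg * ce_neg / n_neg


# Toy batch: one 2x2 image with one positive and one negative pixel.
y_true = np.zeros((1, 2, 2, 2))
y_true[0, 0, 0, 0] = 1.0
y_true[0, 1, 1, 1] = 1.0
y_pred = np.full((1, 2, 2), 0.5)
loss = masked_cross_entropy(y_true, y_pred)
```

With every prediction at 0.5, each annotated pixel contributes -log(0.5), so the toy loss is 1.1 * log(2). Feeding the same arrays to loss1() (for example through K.constant and K.eval) and comparing the numbers would show whether clipping y_true or dividing by the counts changes the result.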