Я определил следующую функцию потерь:
def dice_coef(y_true, y_pred):
y_pred = K.gather(y_pred, tf.where(y_pred>0.5))
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersect = K.sum(y_pred_f * y_true_f)
denominator = K.sum(y_pred_f) + K.sum(y_true_f)
dice_score = K.constant(2.) * intersect / (denominator + K.constant(.01))
return dice_score
Поскольку я хочу, чтобы y_pred был 0 и 1 с, я последовал за потоком в стеке, который предложил использовать y_pred = K.gather(y_pred, tf.where(y_pred>0.5))
. Однако я получаю сообщение об ошибке:
ResourceExhaustedError: OOM при выделении тензора с формой [662885,4,144,144,1]
Есть какое-нибудь решение этой проблемы?