Я пытаюсь реализовать пользовательскую функцию потерь в кератах для задачи «Частичное изучение меток». В моем учебном наборе каждому учебному экземпляру назначается набор из двух ярлыков кандидатов, только один из которых является правильным. Для этой цели я хочу использовать функцию потерь, которая во время обучения будет рассчитывать потери для каждой метки и выбирать потери с минимальным значением. Упрощенная версия этой функции будет выглядеть примерно так:
def custom_loss(y_true, y_pred):
num_labels = tf.reduce_sum(y_true) # [0,1,0,0,1]
if num_labels > 1: #create 2 seperate vectors
y_true_1 = ? # [0,1,0,0,0]
y_true_2 = ? # [0,0,0,0,1]
loss_1 = K.categorical_crossentropy(y_true_1, y_pred)
loss_2 = K.categorical_crossentropy(y_true_2, y_pred)
loss = minimum(loss_1, loss_2)
else:
loss = K.categorical_crossentropy(y_true, y_pred)
return loss
Я пытался сделать это так:
y_true = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 0.])
y_pred = tf.constant([.9, .05, .05, .5, .89, .6, .05, .01, .94])
def custom_loss(y_true, y_pred):
def train_loss():
y_train_copy = tf.Variable(0, dtype=y_true.dtype)
y_train_copy = tf.assign(y_train_copy, y_true, validate_shape=False)
label_cls = tf.where(tf.equal(y_true,1))
raplace = tf.Variable([0.]) #Variable
y_true_1 = tf.compat.v1.scatter_nd_update(y_train_copy, [label_cls[0]], raplace) # [0,1,0,0,0]
y_true_2 = tf.compat.v1.scatter_nd_update(y_train_copy, [label_cls[1]], raplace) # [0,0,0,0,1]
loss_1 = K.categorical_crossentropy(y_true_1, y_pred)
loss_2 = K.categorical_crossentropy(y_true_2, y_pred)
min_loss = tf.minimum(loss_1, loss_2)
return min_loss
num_labels = tf.reduce_sum(y_true) # [0,1,0,0,1]
loss = tf.cond(num_labels > 1,
lambda: train_loss(),
lambda: K.categorical_crossentropy(y_true, y_pred)) #
return loss
loss = custom_loss(y_true, y_pred)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(loss))
Проблема в том, что по какой-то причине, независимо от того, как я пытаюсьчтобы получить минимум из двух потерь, я получаю 0.0, даже когда loss_1 и loss_2 определенно не равны 0
Есть идеи почему? или лучшая идея для реализации этой функции?