Я пытаюсь реализовать взвешенную перекрестную энтропию из TF в Керасе.Документация с сайта TF: https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits
Вот что я делаю:
import tensorflow as tf
from keras import backend as K
# Create the custom loss function
def weighted_binary_crossentropy(weights):
def w_binary_crossentropy(y_true, y_pred):
return K.mean(tf.nn.weighted_cross_entropy_with_logits(
y_true,
y_pred,
weights,
name=None
), axis=-1)
return w_binary_crossentropy
# Optimizers, Loss and Compile
adam = Adam(lr=0.0001)
weighted_loss = weighted_binary_crossentropy(weights=1)
model.compile(optimizer=adam, loss=weighted_loss, metrics=['accuracy'])
Тренинг начался, но Loss не обновляется / застревает.Я ожидаю, что если я установлю вес на 1, то результат будет таким же, как и при стандартной кросс-энтропийной потере.Я что-то пропустил?