Кросентропия под заказ metri c в модели Keras - PullRequest
0 голосов
/ 26 февраля 2020

Я решаю проблему несбалансированной классификации нескольких классов (3 класса) с помощью Keras. Чтобы справиться с несбалансированными классами, я использую пользовательский метри c, созданный функцией w_categorical_crossentropy, как указано здесь .

Однако моя реализация завершается ошибкой со следующей ошибкой:

TypeError: w_categorical_crossentropy () получил неожиданный аргумент ключевого слова 'sample_weight'

Я нигде не использую sample_weight.

Keras

import functools
import keras.backend as K
from itertools import product
from keras.optimizers import SGD, Adam, RMSprop

NUM_CLASSES = 3
BATCH_SIZE = 128

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask


dropout_rate=0.20
hidden_units=128

model = Sequential()

model.add(LSTM(
                units=nb_features, # the number of hidden states
                return_sequences=True, 
                input_shape=(timestamps,nb_features),
                kernel_regularizer=l2(0.01),
                recurrent_regularizer=l2(0.01),
                bias_regularizer=l2(0.01),
                dropout=dropout_rate,
                recurrent_dropout=0.20
              )
         )

model.add(Dense(units=hidden_units,
                kernel_initializer='normal',
                kernel_regularizer=l2(0.01),
                activation='relu'
               ))

model.add(BatchNormalization())

model.add(LeakyReLU(alpha=0.5))

model.add(Dropout(dropout_rate))

#model.add(TimeDistributed(Dense(1))) #units=round(timestamps/2),activation='relu')

model.add(Dense(units=hidden_units, 
                kernel_initializer='normal',
                kernel_regularizer=l2(0.01),
                activation='relu'))

model.add(BatchNormalization())

model.add(Dropout(dropout_rate))

model.add(Flatten())

model.add(Dense(units=hidden_units, 
                kernel_initializer='uniform',
                kernel_regularizer=l2(0.01),
                activation='relu'))

model.add(BatchNormalization())

model.add(Dropout(dropout_rate))

model.add(Dense(units=nb_classes,
                activation='softmax'))

w_array = np.ones((nb_classes,nb_classes))
w_array[1, 2] = 1.0
w_array[2, 1] = 1.2

ncce = functools.partial(w_categorical_crossentropy, weights=w_array)

# Define a performance metric
#sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)

rms = RMSprop()

model.compile(loss=ncce,
              metrics = ["accuracy"],
              optimizer= rms # "adam"
             )

Затем я также попытался использовать эту реализацию:

def weighted_categorical_crossentropy(weights):
    """
    A weighted version of keras.objectives.categorical_crossentropy

    Variables:
        weights: numpy array of shape (C,) where C is the number of classes

    Usage:
        weights = np.array([0.5,2,10]) # Class one at 0.5, class 2 twice the normal weights, class 3 10x.
        loss = weighted_categorical_crossentropy(weights)
        model.compile(loss=loss,optimizer='adam')
    """

    weights = K.variable(weights)

    def loss(y_true, y_pred):
        # scale predictions so that the class probas of each sample sum to 1
        y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
        # clip to prevent NaN's and Inf's
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        # calc
        loss = y_true * K.log(y_pred) * weights
        loss = -K.sum(loss, -1)
        return loss

    return loss

ncce = w_categorical_crossentropy(weights = np.array([0.5,2,10]))

Но в этом случае я получаю следующую ошибку:

IndexError: слишком много индексы для массива

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...