Пользовательская метрика keras с весами выборки - PullRequest
0 голосов
/ 22 февраля 2019

Я пытаюсь определить пользовательскую метрику в Keras, которая учитывает веса выборки.При подгонке модели я использую выборочные веса следующим образом:

training_history = model.fit(
        train_data,
        train_labels,
        sample_weight = train_weights,
        epochs = num_epochs,
        batch_size = 128,
        validation_data = (validation_data, validatation_labels, validation_weights ),
    )

Примером пользовательской метрики, которую я использую, является AUC (площадь под кривой roc), которую я определил следующим образом:

from keras import backend as K
import tensorflow as tf

def auc(true_labels, predictions, weights = None):
    auc = tf.metrics.auc(true_labels, predictions, weights = weights)[1]
    K.get_session().run(tf.local_variables_initializer())
    return auc

и я использую эту метрику при компиляции модели:

model.compile(
        optimizer = optimizer,
        loss = 'binary_crossentropy',
        metrics = ['accuracy', auc]
    )

Но, насколько я могу судить, метрика не учитывает веса выборки.Фактически я проверил это, сравнив значение метрики, которое я вижу при обучении модели с использованием пользовательской метрики, определенной выше, с тем, что я получаю, вычисляя ее самостоятельно из выходных данных модели и весов выборки, которые действительно дают очень разные результаты.Как бы я определил показанную выше метрику auc для учета веса выборки?

1 Ответ

0 голосов
/ 22 февраля 2019

Вы можете обернуть свою метрику другой функцией, которая принимает sample_weights в качестве аргумента:

def auc(weights):
    def metric(true_labels, predictions):
        auc = tf.metrics.auc(true_labels, predictions, weights=weights)[1]
        K.get_session().run(tf.local_variables_initializer())
        return auc
    return metric

И затем определить дополнительный входной заполнитель, который будет получать веса выборки:

sample_weights = Input(shape=(1,))

Ваша модель может быть скомпилирована следующим образом:

model.compile(
    optimizer = optimizer,
    loss = 'binary_crossentropy',
    metrics = ['accuracy', auc(sample_weights)]
)

ПРИМЕЧАНИЕ: Не проверено.

...