Использовать tf.metrics в Керасе? - PullRequest
0 голосов
/ 26 мая 2018

Меня особенно интересует specificity_at_sensitivity.Просматривая документы Keras :

from keras import metrics

model.compile(loss='mean_squared_error',
              optimizer='sgd',
              metrics=[metrics.mae, metrics.categorical_accuracy])

Но похоже, что список metrics должен иметь функции arity 2, принимая (y_true, y_pred) и возвращая единственное тензорное значение.


РЕДАКТИРОВАТЬ: В настоящее время вот как я делаю вещи:

from sklearn.metrics import confusion_matrix

predictions = model.predict(x_test)
y_test = np.argmax(y_test, axis=-1)
predictions = np.argmax(predictions, axis=-1)
c = confusion_matrix(y_test, predictions)
print('Confusion matrix:\n', c)
print('sensitivity', c[0, 0] / (c[0, 1] + c[0, 0]))
print('specificity', c[1, 1] / (c[1, 1] + c[1, 0]))

Недостатком этого подхода является то, что я получаю результат, который мне небезразличен только после окончания обучения.Предпочел бы получать метрики каждые 10 эпох или около того.

Ответы [ 2 ]

0 голосов
/ 26 мая 2018

Я обнаружил связанную проблему на github , и кажется, что tf.metrics по-прежнему не поддерживается моделями Keras.Однако, если вы очень заинтересованы в использовании tf.metrics.specificity_at_sensitivity , я бы предложил следующий обходной путь (на основе решения БогданРужа ):

def specificity_at_sensitivity(sensitivity, **kwargs):
    def metric(labels, predictions):
        # any tensorflow metric
        value, update_op = tf.metrics.specificity_at_sensitivity(labels, predictions, sensitivity, **kwargs)

        # find all variables created for this metric
        metric_vars = [i for i in tf.local_variables() if 'specificity_at_sensitivity' in i.name.split('/')[2]]

        # Add metric variables to GLOBAL_VARIABLES collection.
        # They will be initialized for new session.
        for v in metric_vars:
            tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, v)

        # force to update metric values
        with tf.control_dependencies([update_op]):
            value = tf.identity(value)
            return value
    return metric


model.compile(loss='mean_squared_error',
              optimizer='sgd',
              metrics=[metrics.mae,
                       metrics.categorical_accuracy,
                       specificity_at_sensitivity(0.5)])

ОБНОВЛЕНИЕ:

Вы можете использовать model.evaluate , чтобы получить метрики после обучения.

0 голосов
/ 26 мая 2018

Я не думаю, что существует строгое ограничение только для двух входящих аргументов, в metrics.py функция представляет собой всего три входящих аргумента, но k выбирает значение по умолчанию 5.

def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
    return K.mean(K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)
...