Показатель точности TensorFlow Keras под капотом - PullRequest
3 голосов
/ 19 сентября 2019

При построении классификатора с использованием TensorFlow Keras часто следят за точностью модели, указав metrics=['accuracy'] на этапе компиляции:

model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])

Это ведет себя правильно, независимо от того, выводит ли модель логиты или вероятности классов, инезависимо от того, ожидает ли модель наземные метки истинности как векторы с горячим кодированием или целочисленные индексы (т. е. целые числа в интервале [0, n_classes)).

Это не тот случай, если кто-то хочет использовать перекрестныепотеря энтропии: каждая из четырех комбинаций упомянутых выше случаев требует, чтобы на этапе компиляции передавалось различное значение loss:

  1. Если модель выдает вероятности и метки истинности землис горячим кодированием, тогда loss='categorical_crossentropy' работает.

  2. Если модель выдает вероятности, а метки истинности заземления являются целочисленными индексами, то loss='sparse_categorical_crossentropy' работает.

  3. Если модель выводит логиты, а метки истинности заземления кодируются одним горячим кодом, то loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True) работает.

  4. IЕсли модель выводит логиты, а метки истинности основания являются целочисленными индексами, то loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) работает.

Кажется, что просто указание loss='categorical_crossentropy' недостаточно надежно для обработки этих четырех случаев,тогда как указание metrics=['accuracy'] является достаточно надежным.


Вопрос Что происходит за кулисами, когда пользователь указывает metrics=['accuracy'] на этапе компиляции модели, которыйпозволяет правильно вычислить точность независимо от того, выводит ли модель логиты или вероятности, а также являются ли метки истинности земли кодами с горячим кодированием или целочисленными индексами?


Я подозреваю, что логиты и вероятностиСлучай прост, так как предсказанный класс может быть получен как argmax в любом случае, но в идеале я хотел бы указать, где в исходном коде TensorFlow 2 вычисления фактически выполняются.

Обратите внимание, что я в настоящее времяиспользуя TensorFlow 2.0.0-rc1 .


Редактировать В чистом видеKeras, metrics=['accuracy'] явно обрабатывается в методе Model.compile .

1 Ответ

2 голосов
/ 19 сентября 2019

Нашел: это обработано в tensorflow.python.keras.engine.training_utils.get_metric_function.В частности, выходная форма проверяется, чтобы определить, какую функцию точности использовать.

Для уточнения, в текущей реализации Model.compile либо делегирует обработку метрики Model._compile_eagerly (если выполняется с нетерпением) или делает это напрямую.В любом случае вызывается Model._cache_output_metric_attributes, что вызывает collect_per_output_metric_info как для невзвешенной, так и для взвешенной метрики.Эта функция перебирает предоставленные метрики, вызывая get_metric_function для каждого.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...