При построении классификатора с использованием TensorFlow Keras часто следят за точностью модели, указав metrics=['accuracy']
на этапе компиляции:
model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])
Это ведет себя правильно, независимо от того, выводит ли модель логиты или вероятности классов, инезависимо от того, ожидает ли модель наземные метки истинности как векторы с горячим кодированием или целочисленные индексы (т. е. целые числа в интервале [0, n_classes)
).
Это не тот случай, если кто-то хочет использовать перекрестныепотеря энтропии: каждая из четырех комбинаций упомянутых выше случаев требует, чтобы на этапе компиляции передавалось различное значение loss
:
Если модель выдает вероятности и метки истинности землис горячим кодированием, тогда loss='categorical_crossentropy'
работает.
Если модель выдает вероятности, а метки истинности заземления являются целочисленными индексами, то loss='sparse_categorical_crossentropy'
работает.
Если модель выводит логиты, а метки истинности заземления кодируются одним горячим кодом, то loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)
работает.
IЕсли модель выводит логиты, а метки истинности основания являются целочисленными индексами, то loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
работает.
Кажется, что просто указание loss='categorical_crossentropy'
недостаточно надежно для обработки этих четырех случаев,тогда как указание metrics=['accuracy']
является достаточно надежным.
Вопрос Что происходит за кулисами, когда пользователь указывает metrics=['accuracy']
на этапе компиляции модели, которыйпозволяет правильно вычислить точность независимо от того, выводит ли модель логиты или вероятности, а также являются ли метки истинности земли кодами с горячим кодированием или целочисленными индексами?
Я подозреваю, что логиты и вероятностиСлучай прост, так как предсказанный класс может быть получен как argmax в любом случае, но в идеале я хотел бы указать, где в исходном коде TensorFlow 2 вычисления фактически выполняются.
Обратите внимание, что я в настоящее времяиспользуя TensorFlow 2.0.0-rc1 .
Редактировать В чистом видеKeras, metrics=['accuracy']
явно обрабатывается в методе Model.compile
.