В TensorFlow 2.0 есть класс tf.keras.metrics.AUC
. Его можно легко добавить в список метрик метода compile
следующим образом.
# Example taken from the documentation
model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.AUC()])
Однако в моем случае выходные данные моей нейронной сети имеют тензор NxM
, где N
- размер партии, а M
- количество отдельных выходов. Я хотел бы вычислить AU C metri c для каждого из этих M
выходов отдельно (для всех N
экземпляров пакета). Итак, должно быть M
AU C метрик, каждая из которых рассчитывается с N
наблюдениями. Я пытался создать собственную метри c, но у меня возникли некоторые проблемы. Следующее - моя первая попытка.
def get_custom_auc(output):
auc = tf.metrics.AUC()
@tf.function
def custom_auc(y_true, y_pred):
y_true = y_true[:, output]
y_pred = y_pred[:, output]
auc.update_state(y_true, y_pred)
return auc.result()
custom_auc.__name__ = "custom_auc_" + str(output)
return custom_auc
Необходимость переименования custom_auc.__name__
описана в следующем посте: Возможно ли иметь metri c, который возвращает массив (или тензор) ) а не число? . Однако эта реализация вызывает ошибку.
tenorflow. python .framework.errors_impl.InvalidArgumentError: сбой утверждения: [предсказания должны быть> = 0] [Условие x> = y не содержит элемент -wise:] [x (strided_slice_1: 0) =] [3.14020467 3.06779885 2.86414027 ...] [y (Cast_1 / x: 0) =] [0] [[{{метрика узла / custom_auc_2 / StatefulPartitionedCall / assert_greater_equal / Assert / AssertGuard / else / _161 / Assert}}]] [Op: __ inference_keras_scratch_graph_5149]
Я также пытался создать объект AUC
внутри custom_auc
, но это невозможно, потому что я используя @tf.function
, поэтому я получу ошибку ValueError: tf.function-decorated function tried to create variables on non-first call.
. Даже если я удаляю @tf.function
(который мне может понадобиться, потому что я могу использовать некоторые операторы if-else внутри реализации), я получаю еще одну ошибку
tenorflow. python .framework.errors_impl. FailedPreconditionError: Ошибка при чтении переменной ресурса _AnonymousVar33 из контейнера: localhost. Это может означать, что переменная была неинициализирована. Не найдено: ресурс localhost / _AnonymousVar33 / N10tensorflow3VarE не существует. [[node metrics / custom_auc_0 / add / ReadVariableOp (определено в /train.py:173)]] [Op: __ inference_keras_scratch_graph_5174]
Обратите внимание, что в настоящее время я добавляю эти метрики AU C по одному на каждый из M
выходов, как описано в этого ответа . Кроме того, я не могу просто вернуть объект auc
, потому что, очевидно, Керас ожидает, что выходные данные пользовательского метри c будут тензорными, а не объектом AU C. Поэтому, если вы это сделаете, вы получите следующую ошибку:
TypeError: Для совместимости с tf.contrib.eager.defun, функции Python должны возвращать ноль или более Tensors; при компиляции .custom_au c в 0x1862e6680> нашел возвращаемое значение типа, который не является тензором.
Я также пытался реализовать пользовательский класс metri c следующим образом.
class CustomAUC(tf.metrics.Metric):
def __init__(self, num_outputs, name="custom_auc", **kwargs):
super(CustomAUC, self).__init__(name=name, **kwargs)
assert num_outputs >= 1
self.num_outputs = num_outputs
self.aucs = [tf.metrics.AUC() for _ in range(self.num_outputs)]
def update_state(self, y_true, y_pred, sample_weight=None):
for output in range(self.num_outputs):
y_true1 = y_true[:, output]
y_pred1 = y_pred[:, output]
self.aucs[output].update_state(y_true1, y_pred1)
def result(self):
return [auc.result() for auc in self.aucs]
Однако в настоящее время я получаю сообщение об ошибке
ValueError: Shapes (200,) и () несовместимы
Эта ошибка кажется быть связанным с reset_states
, поэтому, возможно, мне следует переопределить этот метод. Фактически, если я переопределяю reset_states
со следующей реализацией
def reset_states(self):
for auc in self.aucs:
auc.reset_states()
, я больше не получаю эту ошибку, но я получаю другую ошибку
tenorflow. python .framework.errors_impl.InvalidArgumentError: утверждение не выполнено: [предсказания должны быть> = 0] [Условие x> = y не содержало поэлементно:] [x (strided_slice_1: 0) =] [-1.38822043 1.24234951 -0.254447281 ... ] [y (Cast_1 / x: 0) =] [0] [[{{метрики узла / custom_auc / PartitionedFunctionCall / assert_greater_equal / Assert / AssertGuard / else / _98 / Assert}}]] [Op: __ inference_keras_scratch_graph_5248]
Итак, как мне реализовать этот пользовательский AU C metri c, по одному для каждого из M
выходов сети? По сути, я хочу сделать что-то похожее на решение, описанное в этом ответе , но с AU C metri c.
Я также открыл проблему, связанную с в системе отслеживания ошибок TensorFlow's Github.