Как реализовать функцию потери мультиклассовых кубиков без использования функции argmax (argmax не дифференцируется)? - PullRequest
0 голосов
/ 13 мая 2019

Я пытаюсь реализовать функцию потери мультикласса в тензорном потоке. Так как это игра с несколькими классами, мне нужно преобразовать вероятности каждого класса в одну горячую форму. Например, если моя сеть выдает эти вероятности:
[0,2, 0,6, 0,1, 0,1] (при условии 4 класса)
Мне нужно преобразовать это в:
[0 1 0 0]
Это можно сделать с помощью tf.argmax , за которым следует tf.one_hot

def generalized_dice_loss(labels, logits):
 #labels shape [batch_size,128,128,64,1] dtype=float32
 #logits shape [batch_size,128,128,64,7] dtype=float32
 labels=tf.cast(labels,tf.int32)
 smooth = tf.constant(1e-17)
 shape = tf.TensorShape(logits.shape).as_list()
 depth = int(shape[-1])
 labels = tf.one_hot(labels, depth, dtype=tf.int32,axis=4)
 labels = tf.squeeze(labels, axis=5)
 logits = tf.argmax(logits,axis=4)
 logits = tf.one_hot(logits, depth, dtype=tf.int32,axis=4)
 numerator = tf.reduce_sum(labels * logits, axis=[1, 2, 3])
 denominator = tf.reduce_sum(labels + logits, axis=[1, 2, 3])
 numerator=tf.cast(numerator,tf.float32)
 denominator=tf.cast(denominator,tf.float32)
 loss = tf.reduce_mean(1.0 - 2.0*(numerator + smooth)/(denominator + smooth))
 return loss

Проблема в том, что tf.argmax не дифференцируется, выдает ошибку:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

Как решить эту проблему? Можем ли мы сделать то же самое, не используя tf.argmax?

1 Ответ

0 голосов
/ 13 мая 2019

Взгляните на Как дифференцируется потеря гладких костей? .Вам не нужно делать конвертацию (конвертировать [0.2, 0.6, 0.1, 0.1] в [0 1 0 0]).Просто оставьте их как непрерывное значение от 0 до 1.

Если я правильно понимаю, функция потерь - это просто суррогат для достижения ожидаемой цели.Несмотря на то, что это не то же самое, пока это хороший прокси, это хорошо (иначе, это не дифференцируемо).

Во время оценки не стесняйтесь использовать tf.argmax, чтобы получить реальную метрику.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...