Я пытаюсь реализовать функцию потери мультикласса в тензорном потоке. Так как это игра с несколькими классами, мне нужно преобразовать вероятности каждого класса в одну горячую форму. Например, если моя сеть выдает эти вероятности:
[0,2, 0,6, 0,1, 0,1] (при условии 4 класса)
Мне нужно преобразовать это в:
[0 1 0 0]
Это можно сделать с помощью tf.argmax , за которым следует tf.one_hot
def generalized_dice_loss(labels, logits):
#labels shape [batch_size,128,128,64,1] dtype=float32
#logits shape [batch_size,128,128,64,7] dtype=float32
labels=tf.cast(labels,tf.int32)
smooth = tf.constant(1e-17)
shape = tf.TensorShape(logits.shape).as_list()
depth = int(shape[-1])
labels = tf.one_hot(labels, depth, dtype=tf.int32,axis=4)
labels = tf.squeeze(labels, axis=5)
logits = tf.argmax(logits,axis=4)
logits = tf.one_hot(logits, depth, dtype=tf.int32,axis=4)
numerator = tf.reduce_sum(labels * logits, axis=[1, 2, 3])
denominator = tf.reduce_sum(labels + logits, axis=[1, 2, 3])
numerator=tf.cast(numerator,tf.float32)
denominator=tf.cast(denominator,tf.float32)
loss = tf.reduce_mean(1.0 - 2.0*(numerator + smooth)/(denominator + smooth))
return loss
Проблема в том, что tf.argmax не дифференцируется, выдает ошибку:
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
Как решить эту проблему? Можем ли мы сделать то же самое, не используя tf.argmax?