
I am trying to implement a multi-class Dice loss function in TensorFlow. Since it is a multi-class Dice loss, I need to convert each prediction's class probabilities into one-hot form. For example, if my network outputs these probabilities:
[0.2, 0.6, 0.1, 0.1] (assuming 4 classes)
I need to convert this into:
[0 1 0 0]
This can be done with tf.argmax followed by tf.one_hot:
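For a single prediction vector, that conversion looks like this. It is a small sketch that assumes TensorFlow 2.x with eager execution:

```python
import tensorflow as tf

# Probabilities for one voxel over 4 classes.
probs = tf.constant([0.2, 0.6, 0.1, 0.1])

# Index of the most likely class (here 1).
class_id = tf.argmax(probs, axis=-1)

# Turn the index into a one-hot vector of length 4.
hard = tf.one_hot(class_id, depth=4, dtype=tf.int32)

print(hard.numpy())  # [0 1 0 0]
```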

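import tensorflow as tf

def generalized_dice_loss(labels, logits):
    # labels shape [batch_size, 128, 128, 64, 1], dtype=float32
    # logits shape [batch_size, 128, 128, 64, 7], dtype=float32
    labels = tf.cast(labels, tf.int32)
    smooth = tf.constant(1e-17)
    shape = tf.TensorShape(logits.shape).as_list()
    depth = int(shape[-1])  # number of classes (7)
    # One-hot the ground truth: [..., 1] -> [..., depth, 1] -> [..., depth]
    labels = tf.one_hot(labels, depth, dtype=tf.int32, axis=4)
    labels = tf.squeeze(labels, axis=5)
    # Hard predictions: argmax over classes, then one-hot (the non-differentiable step)
    logits = tf.argmax(logits, axis=4)
    logits = tf.one_hot(logits, depth, dtype=tf.int32, axis=4)
    # Per-sample, per-class overlap and total volume, summed over the spatial axes
    numerator = tf.reduce_sum(labels * logits, axis=[1, 2, 3])
    denominator = tf.reduce_sum(labels + logits, axis=[1, 2, 3])
    numerator = tf.cast(numerator, tf.float32)
    denominator = tf.cast(denominator, tf.float32)
    # smooth sits outside the factor 2, so a class absent from both gives Dice 1 (loss 0), not 2
    loss = tf.reduce_mean(1.0 - (2.0 * numerator + smooth) / (denominator + smooth))
    return loss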

The problem is that tf.argmax is not differentiable, so training throws this error:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
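The underlying problem can be reproduced in a few lines with tf.GradientTape. This is a sketch assuming TensorFlow 2.x eager execution, and the 4-class logits are made up for illustration. Anything computed through tf.argmax, as in generalized_dice_loss above, has a `None` gradient with respect to the logits, which is what the ValueError complains about, while a softmax path does not:

```python
import tensorflow as tf

# Made-up logits for one voxel over 4 classes.
logits = tf.Variable([[2.0, 1.0, 0.5, 0.1]])

with tf.GradientTape() as tape:
    # Hard path: argmax -> one_hot. argmax returns integers and has no gradient.
    hard = tf.one_hot(tf.argmax(logits, axis=-1), depth=4)
    hard_loss = tf.reduce_sum(hard)
hard_grad = tape.gradient(hard_loss, logits)

with tf.GradientTape() as tape:
    # Soft path: softmax probabilities stay differentiable.
    soft = tf.nn.softmax(logits, axis=-1)
    soft_loss = tf.reduce_sum(soft * tf.constant([0.0, 1.0, 0.0, 0.0]))
soft_grad = tape.gradient(soft_loss, logits)

print(hard_grad)              # None
print(soft_grad is not None)  # True
```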

How can I solve this problem? Can the same thing be done without tf.argmax?

1 Answer


Take a look at the question "How is the smooth dice loss differentiable?". You don't need to do the conversion (turning [0.2, 0.6, 0.1, 0.1] into [0 1 0 0]); just leave the predictions as continuous values between 0 and 1.
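A minimal sketch of such a soft multi-class Dice loss follows. It assumes TensorFlow 2.x, the same label/logit shapes as in the question, and that the network outputs raw logits, so softmax is applied inside. The name soft_dice_loss and the default smooth value are my own choices. This is the plain per-class Dice averaged over classes, without the per-class weights that a weighted "generalized" Dice loss would add:

```python
import tensorflow as tf

def soft_dice_loss(labels, logits, smooth=1e-6):
    """Differentiable multi-class Dice loss (unweighted, averaged over classes).

    labels: class ids, shape [batch, D, H, W, 1] (any numeric dtype)
    logits: raw network outputs, shape [batch, D, H, W, num_classes]
    """
    num_classes = logits.shape[-1]
    # The ground truth is a constant, so one-hot encoding it needs no gradient.
    labels = tf.cast(tf.squeeze(labels, axis=-1), tf.int32)
    target = tf.one_hot(labels, depth=num_classes, dtype=tf.float32)
    # Keep predictions continuous: softmax probabilities instead of argmax.
    probs = tf.nn.softmax(logits, axis=-1)
    spatial_axes = [1, 2, 3]
    intersection = tf.reduce_sum(target * probs, axis=spatial_axes)  # [batch, classes]
    total = tf.reduce_sum(target + probs, axis=spatial_axes)         # [batch, classes]
    dice = (2.0 * intersection + smooth) / (total + smooth)
    return tf.reduce_mean(1.0 - dice)

# Usage: gradients now reach the logits.
labels = tf.zeros([1, 4, 4, 4, 1])
logits = tf.Variable(tf.random.normal([1, 4, 4, 4, 7]))
with tf.GradientTape() as tape:
    loss = soft_dice_loss(labels, logits)
grad = tape.gradient(loss, logits)
```

Because each per-class Dice value lies between 0 and 1, the loss also stays in [0, 1].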

If I understand correctly, the loss function is just a surrogate for the objective you actually care about. The soft Dice loss is not the same as the hard Dice score, but that is fine as long as it is a good proxy; the hard version cannot be used for training anyway, because it is not differentiable.
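To make the proxy relationship concrete (the notation is mine): for one class, let p_i be the predicted probability at voxel i, g_i the one-hot ground truth (0 or 1), and s a small smoothing constant. The soft Dice coefficient and the corresponding loss are then

```latex
D_{\mathrm{soft}} = \frac{2\sum_i p_i g_i + s}{\sum_i p_i + \sum_i g_i + s},
\qquad
\mathcal{L}_{\mathrm{Dice}} = 1 - D_{\mathrm{soft}}
```

This is a smooth function of the p_i, so gradients flow back to the logits. When every p_i is exactly 0 or 1, which is what argmax followed by one_hot produces, it reduces to the hard Dice coefficient, so the closer the softmax outputs get to 0 or 1, the closer the two quantities are.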

At evaluation time, feel free to use tf.argmax to compute the real metric.
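An evaluation-only metric might look like the sketch below. It assumes the same shapes as the question's code; hard_dice_score is a hypothetical helper name. It is essentially the question's original computation: argmax and one_hot are fine here because nothing needs a gradient.

```python
import tensorflow as tf

def hard_dice_score(labels, logits, smooth=1e-6):
    """Evaluation-only multi-class Dice score (not differentiable).

    labels: class ids, shape [batch, D, H, W, 1]
    logits: raw network outputs, shape [batch, D, H, W, num_classes]
    Returns the mean Dice over batch and classes.
    """
    num_classes = logits.shape[-1]
    labels = tf.cast(tf.squeeze(labels, axis=-1), tf.int32)
    target = tf.one_hot(labels, depth=num_classes, dtype=tf.float32)
    # argmax is acceptable here: no gradient is needed at evaluation time.
    pred = tf.one_hot(tf.argmax(logits, axis=-1), depth=num_classes, dtype=tf.float32)
    spatial_axes = [1, 2, 3]
    intersection = tf.reduce_sum(target * pred, axis=spatial_axes)
    total = tf.reduce_sum(target + pred, axis=spatial_axes)
    dice = (2.0 * intersection + smooth) / (total + smooth)
    return tf.reduce_mean(dice)

labels = tf.zeros([1, 2, 2, 2, 1])
good = 10.0 * tf.one_hot(tf.zeros([1, 2, 2, 2], dtype=tf.int32), depth=3)
bad = 10.0 * tf.one_hot(tf.ones([1, 2, 2, 2], dtype=tf.int32), depth=3)
print(float(hard_dice_score(labels, good)))  # 1.0
print(float(hard_dice_score(labels, bad)))   # about 0.33: only the absent class 2 "matches"
```

With this smoothing, a class that is absent from both the ground truth and the prediction scores 1; depending on your data you may prefer to exclude such classes from the average.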
