Losses for Image Segmentation


In this post, I will implement some of the most common losses for image segmentation in Keras/TensorFlow. I will only consider the case of two classes (i.e. binary). If you know any other losses, let me know and I will add them.

16.08.2019: improved overlap measures, added CE+DL loss

Cross Entropy

Let $p \in \{0, 1\}$ be the ground truth for a pixel and $\hat{p} \in (0, 1)$ the prediction. The predictions are given by the logistic/sigmoid function $\hat{p} = \sigma(x) = \frac{1}{1 + e^{-x}}$, where $x$ is the logit. Then cross entropy (CE) can be defined as follows:

$$\text{CE}\left(p, \hat{p}\right) = -\left(p \log\left(\hat{p}\right) + \left(1 - p\right) \log\left(1 - \hat{p}\right)\right)$$

In Keras, the loss function is binary_crossentropy(y_true, y_pred) and in TensorFlow, it is tf.nn.sigmoid_cross_entropy_with_logits (softmax_cross_entropy_with_logits_v2 is the multi-class counterpart).
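For example, assuming the last layer of the network applies a sigmoid (model and optimizer are placeholders, not defined here), the built-in Keras loss can be used directly:

from tensorflow.keras.losses import binary_crossentropy

# the last layer outputs probabilities via a sigmoid, so the built-in loss applies directly
model.compile(optimizer=optimizer, loss=binary_crossentropy)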

Weighted cross entropy

Weighted cross entropy (WCE) is a variant of CE where all positive examples get weighted by some coefficient. It is used in the case of class imbalance. For example, when you have an image with 10% black pixels and 90% white pixels, regular CE won’t work very well.

WCE can be defined as follows:

$$\text{WCE}\left(p, \hat{p}\right) = -\left(\beta p \log\left(\hat{p}\right) + \left(1 - p\right) \log\left(1 - \hat{p}\right)\right)$$

To decrease the number of false negatives, set $\beta > 1$. To decrease the number of false positives, set $\beta < 1$.
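As a rough rule of thumb (my own estimate, not from the post), $\beta$ can be set to the ratio of negative to positive pixels; for the 10%/90% example above this gives

$$\beta \approx \frac{0.9}{0.1} = 9$$

so that the rare positive pixels contribute roughly as much total weight as the frequent negative ones.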

In TensorFlow, the loss function is weighted_cross_entropy_with_logits. In Keras, we have to implement our own function:

def weighted_cross_entropy(beta):
  def convert_to_logits(y_pred):
      # see https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/keras/backend.py#L3525
      y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())

      return tf.log(y_pred / (1 - y_pred))

  def loss(y_true, y_pred):
    y_pred = convert_to_logits(y_pred)
    loss = tf.nn.weighted_cross_entropy_with_logits(logits=y_pred, targets=y_true, pos_weight=beta)

    # or reduce_sum and/or axis=-1
    return tf.reduce_mean(loss)

  return loss
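The factory returns a closure, so it can be passed directly to compile (a usage sketch; model and optimizer are placeholders, and beta=2.0 is an arbitrary choice):

# beta > 1 penalizes false negatives more strongly than false positives
model.compile(optimizer=optimizer, loss=weighted_cross_entropy(beta=2.0))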

The function convert_to_logits is necessary, because we applied the sigmoid function on y_pred in the last layer of our CNN. Hence, in order to reverse this step, we have to calculate the inverse of the sigmoid (the logit):

$$x = \log\left(\frac{\hat{p}}{1 - \hat{p}}\right)$$

Balanced cross entropy

Balanced cross entropy (BCE) is similar to WCE. The only difference is that we also weight the negative examples.

BCE can be defined as follows:

$$\text{BCE}\left(p, \hat{p}\right) = -\left(\beta p \log\left(\hat{p}\right) + \left(1 - \beta\right)\left(1 - p\right) \log\left(1 - \hat{p}\right)\right)$$

In Keras, it can be implemented as follows:

def balanced_cross_entropy(beta):
  def convert_to_logits(y_pred):
      # see https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/keras/backend.py#L3525
      y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())

      return tf.log(y_pred / (1 - y_pred))

  def loss(y_true, y_pred):
    y_pred = convert_to_logits(y_pred)
    pos_weight = beta / (1 - beta)
    loss = tf.nn.weighted_cross_entropy_with_logits(logits=y_pred, targets=y_true, pos_weight=pos_weight)

    # or reduce_sum and/or axis=-1
    return tf.reduce_mean(loss * (1 - beta))

  return loss

When $\beta = 1$, the denominator in pos_weight is not defined. This can happen when beta is not a fixed value. For example, the paper [1] uses:

beta = tf.reduce_sum(1 - y_true) / (BATCH_SIZE * HEIGHT * WIDTH)

In this case, add a small value like tf.keras.backend.epsilon() to the denominator, or restrict beta with tf.clip_by_value.
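A minimal sketch of the clipped variant (compute_beta is a hypothetical helper; BATCH_SIZE, HEIGHT and WIDTH are the same constants as above):

def compute_beta(y_true):
  beta = tf.reduce_sum(1 - y_true) / (BATCH_SIZE * HEIGHT * WIDTH)
  # keep beta strictly inside (0, 1) so that 1 - beta never becomes zero
  return tf.clip_by_value(beta, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())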

Focal loss

Focal loss (FL) [2] tries to down-weight the contribution of easy examples so that the CNN focuses more on hard examples.

FL can be defined as follows:

$$\text{FL}\left(p, \hat{p}\right) = -\left(\alpha \left(1 - \hat{p}\right)^{\gamma} p \log\left(\hat{p}\right) + \left(1 - \alpha\right) \hat{p}^{\gamma} \left(1 - p\right) \log\left(1 - \hat{p}\right)\right)$$

When $\gamma = 0$, we obtain BCE.

This time we cannot use weighted_cross_entropy_with_logits to implement FL in Keras. Instead, we will derive our own focal_loss_with_logits function.
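A sketch of the numerically stable rewrite the function relies on (my summary, matching the code below): with logit $x$ and $\hat{p} = \sigma(x)$,

$$-\log\left(\hat{p}\right) = \log\left(1 + e^{-x}\right) = \log\left(1 + e^{-|x|}\right) + \max\left(-x, 0\right)$$

$$-\log\left(1 - \hat{p}\right) = \log\left(1 + e^{x}\right) = x + \log\left(1 + e^{-|x|}\right) + \max\left(-x, 0\right)$$

Writing $w_a = \alpha \left(1 - \hat{p}\right)^{\gamma} p$ and $w_b = \left(1 - \alpha\right) \hat{p}^{\gamma} \left(1 - p\right)$, the focal loss becomes

$$\text{FL} = \left(w_a + w_b\right)\left(\log\left(1 + e^{-|x|}\right) + \max\left(-x, 0\right)\right) + x \, w_b$$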

And the implementation is then:

def focal_loss(alpha=0.25, gamma=2):
  def focal_loss_with_logits(logits, targets, alpha, gamma, y_pred):
    weight_a = alpha * (1 - y_pred) ** gamma * targets
    weight_b = (1 - alpha) * y_pred ** gamma * (1 - targets)
    
    return (tf.log1p(tf.exp(-tf.abs(logits))) + tf.nn.relu(-logits)) * (weight_a + weight_b) + logits * weight_b 

  def loss(y_true, y_pred):
    y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())
    logits = tf.log(y_pred / (1 - y_pred))

    loss = focal_loss_with_logits(logits=logits, targets=y_true, alpha=alpha, gamma=gamma, y_pred=y_pred)

    # or reduce_sum and/or axis=-1
    return tf.reduce_mean(loss)

  return loss

Distance to the nearest cell

The paper [3] adds a distance term to cross entropy to force the CNN to learn the separation border between touching objects. The following loss adds a distance term to BCE:

$$\text{BCE}\left(p, \hat{p}\right) + w_0 \cdot \exp\left(-\frac{\left(d_1(x) + d_2(x)\right)^2}{2\sigma^2}\right)$$

where $d_1$ and $d_2$ are two functions that calculate the distance to the nearest and second nearest cell, and $w_0$, $\sigma$ are constants (see [3]).

Computing the distance terms inside the loss function would slow down the training considerably. Hence, precompute the distance map and pass it to the neural network together with the image input.

The following code is a variation that calculates only the distance to the nearest object (a single $d_1$ term).

from scipy.spatial import distance_matrix
import numpy as np

...

# coordinates of foreground (object) and background pixels
not_zeros = np.argwhere(img != 0)
zeros = np.argwhere(img == 0)

# pairwise Euclidean distances from every background pixel to every foreground pixel
dist_matrix = distance_matrix(zeros, not_zeros, p=2)
output = np.zeros((HEIGHT, WIDTH, 1), dtype=np.uint8)

# distance to the nearest foreground pixel for each background pixel
dist = np.min(dist_matrix, axis=1)

i = 0
for y in range(HEIGHT):
  for x in range(WIDTH):
    if img[y,x] == 0:
      # np.argwhere returns indices in row-major order, so dist[i] belongs to (y, x)
      output[y,x] = dist[i]
      i += 1

...
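The double loop can also be replaced by a single fancy-indexing assignment, which should be equivalent (my shortcut, not from the post):

# zeros[:, 0] holds the y coordinates, zeros[:, 1] the x coordinates
output[zeros[:, 0], zeros[:, 1], 0] = np.min(dist_matrix, axis=1)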

For example, on the left is a mask and on the right is the corresponding weight map.

[Image: a mask (left) and the corresponding weight map (right)]

The darker the pixel, the higher the weight of the exponential term. The BCE loss function changes in only one line: pos_weight = beta / (1 - beta) + tf.exp(-tf.pow(weights, 2)). To pass the weight matrix as an additional input, one could use:

from functools import partial
from tensorflow.keras.layers import Input  # or keras.layers, depending on the setup

def loss_function(y_true, y_pred, weights):
  # balanced cross entropy as above, with the distance-dependent pos_weight
  ...

weight_input = Input(shape=(HEIGHT, WIDTH))
loss = partial(loss_function, weights=weight_input)
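A minimal wiring sketch in the spirit of the older functional-API pattern (build_unet, images, weight_maps and masks are placeholders, not from the post): the weight map enters the model as a second input so the loss closure can read it.

from functools import partial
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

image_input = Input(shape=(HEIGHT, WIDTH, 1))
weight_input = Input(shape=(HEIGHT, WIDTH))

# build_unet is a placeholder for the usual segmentation CNN with a sigmoid output
output = build_unet(image_input)
model = Model(inputs=[image_input, weight_input], outputs=output)

loss = partial(loss_function, weights=weight_input)
loss.__name__ = 'weighted_bce'  # some Keras versions expect the loss to have a name

model.compile(optimizer='adam', loss=loss)
model.fit([images, weight_maps], masks, batch_size=BATCH_SIZE)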

Overlap measures

Dice Loss / F1 score

The Dice coefficient (DC) is similar to the Jaccard index (Intersection over Union, IoU):

$$\text{DC} = \frac{2\,TP}{2\,TP + FP + FN} \qquad \text{IoU} = \frac{TP}{TP + FP + FN}$$

where TP are the true positives, FP the false positives and FN the false negatives. We can see that $\text{DC} \geq \text{IoU}$.

The dice coefficient can also be defined as a loss function:

$$\text{DL}\left(p, \hat{p}\right) = 1 - \frac{2 \sum_i p_i \hat{p}_i + 1}{\sum_i p_i + \sum_i \hat{p}_i + 1}$$

where $p_i \in \{0, 1\}$ is the ground truth and $0 \leq \hat{p}_i \leq 1$ the prediction for pixel $i$.

def dice_loss(y_true, y_pred):
  numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=-1)
  denominator = tf.reduce_sum(y_true + y_pred, axis=-1)

  return 1 - (numerator + 1) / (denominator + 1)

Adding one to the numerator and denominator is quite important. For example, when $p = \hat{p} = 0$, the result should be $0$. But without the “+1” term, we get $1 - \frac{2\cdot 0 \cdot 0}{0 + 0} = 1$.

The “+1” term has two effects: (1) shift the range from $[0, 1]$ to $[0, 0.5]$, (2) prevent $\text{DL}\left(p, \hat{p}\right) = 0$, when $p = 0$ and $\hat{p} > 0$. The disadvantage is when $p = 0$, we get $1 - \frac{1}{\hat{p} + 1} = \frac{\hat{p}}{\hat{p} + 1}$.

In an older version of the blog post, I defined DL as in the paper [4]. However, the current version handles cases like $p = \hat{p} = 1$ better.

All loss functions defined so far have always returned tensors. Another possibility is to return a single scalar for each image. This is especially popular when combining loss functions. DL can be redefined as follows:

“+1” is no longer necessary, because $p = \hat{p} = 0$ doesn’t need handling.

def dice_loss(y_true, y_pred):
  numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=(1,2,3))
  denominator = tf.reduce_sum(y_true + y_pred, axis=(1,2,3))

  return 1 - numerator / denominator

In general, dice loss works better when computed over whole images than over single pixels. The same holds for the next loss.

Tversky loss

Tversky index (TI) is a generalization of the Dice coefficient. TI adds weights to FP (false positives) and FN (false negatives):

$$\text{TI}\left(p, \hat{p}\right) = \frac{p \hat{p}}{p \hat{p} + \beta \left(1 - p\right) \hat{p} + \left(1 - \beta\right) p \left(1 - \hat{p}\right)}$$

Let $\beta = \frac{1}{2}$. Then

$$\text{TI}\left(p, \hat{p}\right) = \frac{p \hat{p}}{p \hat{p} + \frac{1}{2}\left(\left(1 - p\right)\hat{p} + p\left(1 - \hat{p}\right)\right)} = \frac{2 p \hat{p}}{p + \hat{p}}$$

which is just the regular Dice coefficient. Similarly to DL, the loss function can be defined as follows [5]:

def tversky_loss(beta):
  def loss(y_true, y_pred):
    numerator = tf.reduce_sum(y_true * y_pred, axis=-1)
    denominator = y_true * y_pred + beta * (1 - y_true) * y_pred + (1 - beta) * y_true * (1 - y_pred)

    return 1 - (numerator + 1) / (tf.reduce_sum(denominator, axis=-1) + 1)

  return loss

Lovász-Softmax

DL and TL simply relax the hard constraint $\hat{p} \in \{0, 1\}$ in order to have a function defined on the domain $[0, 1]$. The paper [6] instead derives a surrogate loss function.

An implementation of Lovász-Softmax can be found on GitHub. Note that this loss requires the identity (linear) activation in the last layer: a negative value means class A and a positive value means class B.

In Keras the loss function can be used as follows:

# lovasz_hinge comes from the official repository (tensorflow/lovasz_losses_tf.py)
def lovasz_softmax(y_true, y_pred):
  return lovasz_hinge(labels=y_true, logits=y_pred)

model.compile(loss=lovasz_softmax, optimizer=optimizer, metrics=[pixel_iou])

Combinations

It is also possible to combine multiple loss functions. The following combination is quite popular in data competitions:

$$\text{CE}\left(p, \hat{p}\right) + \text{DL}\left(p, \hat{p}\right)$$

Note that $\text{CE}$ returns a tensor, while $\text{DL}$ returns a scalar for each image in the batch. This way we combine local ($\text{CE}$) with global information ($\text{DL}$).

def loss(y_true, y_pred):
    def dice_loss(y_true, y_pred):
        numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=(1,2,3))
        denominator = tf.reduce_sum(y_true + y_pred, axis=(1,2,3))

        return tf.reshape(1 - numerator / denominator, (-1, 1, 1))

    return binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

Example: let $\mathbf{P}$ be the ground-truth mask, $\mathbf{\hat{P}}$ the prediction and $\mathbf{L}$ the result of the loss function. Then $\mathbf{L} = \text{CE}\left(\mathbf{P}, \mathbf{\hat{P}}\right) + \text{DL}\left(\mathbf{P}, \mathbf{\hat{P}}\right)$, i.e. the per-pixel cross entropy map shifted by the per-image dice loss.
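A small numeric sketch of how the two parts combine (my own toy values, not the post's original matrices):

import tensorflow as tf
from tensorflow.keras.losses import binary_crossentropy

def dice_loss(y_true, y_pred):
    # image-level dice loss, reshaped so it broadcasts against the per-pixel CE map
    numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=(1,2,3))
    denominator = tf.reduce_sum(y_true + y_pred, axis=(1,2,3))
    return tf.reshape(1 - numerator / denominator, (-1, 1, 1))

# toy 2x2 ground truth and prediction, shape (batch, height, width, channels)
P     = tf.constant([[[[1.0], [0.0]],
                      [[1.0], [0.0]]]])
P_hat = tf.constant([[[[0.8], [0.1]],
                      [[0.6], [0.3]]]])

ce = binary_crossentropy(P, P_hat)   # shape (1, 2, 2): one value per pixel
dl = dice_loss(P, P_hat)             # shape (1, 1, 1): one value per image
L  = ce + dl                         # the image-level DL is added to every pixel of CE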

References

[1] S. Xie and Z. Tu. Holistically-Nested Edge Detection, 2015.

[2] T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar. Focal Loss for Dense Object Detection, 2017.

[3] O. Ronneberger, P. Fischer, and T. Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation, 2015.

[4] F. Milletari, N. Navab, and S.-A. Ahmadi. V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation, 2016.

[5] S. S. M. Salehi, D. Erdogmus, and A. Gholipour. Tversky loss function for image segmentation using 3D fully convolutional deep networks, 2017.

[6] M. Berman, A. R. Triki, M. B. Blaschko. The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks, 2018.
