Loss Functions For Segmentation

27 Sep 2018

In this post, I will implement some of the most common loss functions for image segmentation in Keras/TensorFlow. I will only consider the case of two classes (i.e. binary).

01.09.2020: rewrote lots of parts, fixed mistakes, updated to TensorFlow 2.3

16.08.2019: improved overlap measures, added CE+DL loss

Cross Entropy

We have two probability distributions:

  1. The prediction can either be \(\mathbf{P}(\hat{Y} = 1) = \hat{p}\) or \(\mathbf{P}(\hat{Y} = 0) = 1 - \hat{p}\).
  2. The ground truth can either be \(\mathbf{P}(Y = 1) = p\) or \(\mathbf{P}(Y = 0) = 1 - p\).

The predictions are given by the logistic/sigmoid function \(\hat{p} = \frac{1}{1 + e^{-x}}\) and the ground truth is \(p \in \{0,1\}\).

Then cross entropy (CE) can be defined as follows:

\[\text{CE}\left(p, \hat{p}\right) = -\left(p \log\left(\hat{p}\right) + (1-p) \log\left(1 - \hat{p}\right)\right)\]

In Keras, the loss function is BinaryCrossentropy; in TensorFlow, it is sigmoid_cross_entropy_with_logits. For multiple classes, the counterparts are softmax_cross_entropy_with_logits_v2 and CategoricalCrossentropy/SparseCategoricalCrossentropy. For numerical stability, it is always better to use BinaryCrossentropy with from_logits=True. The model must then not end with a sigmoid activation (e.g. tf.keras.layers.Activation("sigmoid")) or tf.keras.layers.Softmax().
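
A small usage sketch (the tensor values are made up):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

y_true = tf.constant([[1.0, 0.0]])
logits = tf.constant([[2.0, -1.0]])  # raw model outputs, no sigmoid applied

# equivalent to applying the sigmoid first and using from_logits=False,
# but numerically more stable
print(bce(y_true, logits))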

You can see in the original code that TensorFlow sometimes tries to compute the cross entropy from probabilities (when from_logits=False). Due to numerical instabilities, clip_by_value then becomes necessary.

In this post, I will always assume that the final sigmoid activation is not applied (or only during prediction).

Weighted cross entropy

Weighted cross entropy (WCE) is a variant of CE in which all positive examples are weighted by a coefficient. It is used in the case of class imbalance. In segmentation, it is often not necessary, but it can be beneficial when the training of the neural network is unstable. In classification, weighting is mostly used with multiple classes. For the binary case, TensorFlow has no separate tf.nn.weighted_binary_cross_entropy_with_logits; the relevant function is tf.nn.weighted_cross_entropy_with_logits, whose pos_weight argument weights the positive term.

WCE can be defined as follows:

\[\text{WCE}\left(p, \hat{p}\right) = -\left(\beta p \log\left(\hat{p}\right) + (1-p) \log\left(1 - \hat{p}\right)\right)\]

To decrease the number of false negatives, set \(\beta > 1\). To decrease the number of false positives, set \(\beta < 1\).

The implementation looks as follows:

import tensorflow as tf

def weighted_cross_entropy(beta):
  def loss(y_true, y_pred):
    # y_pred contains the logits, y_true the binary ground truth mask
    weight_a = beta * tf.cast(y_true, tf.float32)
    weight_b = 1 - tf.cast(y_true, tf.float32)

    # numerically stable form of beta * p * -log(p_hat) + (1 - p) * -log(1 - p_hat)
    o = (tf.math.log1p(tf.exp(-tf.abs(y_pred))) + tf.nn.relu(-y_pred)) * (weight_a + weight_b) + y_pred * weight_b
    return tf.reduce_mean(o)

  return loss

Loss functions can be set when compiling the model (Keras):

model.compile(loss=weighted_cross_entropy(beta=beta), optimizer=optimizer, metrics=metrics)

If you are wondering why there is a ReLU function in the code, it comes from the numerically stable simplifications; I derive the formula in the section on focal loss.
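
As an aside, TensorFlow's built-in tf.nn.weighted_cross_entropy_with_logits implements the same loss. A minimal sketch of a wrapper (the wrapper name is mine), where pos_weight plays the role of \(\beta\):

def weighted_cross_entropy_builtin(beta):
  def loss(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    # pos_weight multiplies the positive (y_true == 1) term, like beta above
    o = tf.nn.weighted_cross_entropy_with_logits(labels=y_true, logits=y_pred, pos_weight=beta)
    return tf.reduce_mean(o)

  return loss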

The result of a loss function is always a scalar. Some deep learning libraries will automatically apply reduce_mean or reduce_sum if you don't do it. When combining different loss functions, the axis argument of reduce_mean can sometimes become important. Since TensorFlow 2.0, the class BinaryCrossentropy has the argument reduction=losses_utils.ReductionV2.AUTO (exposed publicly as tf.keras.losses.Reduction.AUTO).
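
For example, reduction=tf.keras.losses.Reduction.NONE returns one value per sample instead of a scalar (a small sketch, values made up):

bce_none = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

y_true = tf.constant([[1.0, 0.0], [0.0, 1.0]])
logits = tf.constant([[2.0, -1.0], [0.5, 1.5]])

print(bce_none(y_true, logits))                  # shape (2,): one loss value per sample
print(tf.reduce_mean(bce_none(y_true, logits)))  # scalar, same as the default reduction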

Balanced cross entropy

Balanced cross entropy (BCE) is similar to WCE. The only difference is that we also weight the negative examples.

BCE can be defined as follows:

\[\text{BCE}\left(p, \hat{p}\right) = -\left(\beta p \log\left(\hat{p}\right) + (1 - \beta)(1-p) \log\left(1 - \hat{p}\right)\right)\]

It can be implemented as follows:

def balanced_cross_entropy(beta):
  def loss(y_true, y_pred):
    weight_a = beta * tf.cast(y_true, tf.float32)
    weight_b = (1 - beta) * tf.cast(1 - y_true, tf.float32)
    
    o = (tf.math.log1p(tf.exp(-tf.abs(y_pred))) + tf.nn.relu(-y_pred)) * (weight_a + weight_b) + y_pred * weight_b
    return tf.reduce_mean(o)

  return loss

Instead of using a fixed value like beta = 0.3, it is also possible to dynamically adjust the value of beta. For example, the paper [1] uses: beta = tf.reduce_mean(1 - y_true)
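
Putting this together, a minimal sketch of the dynamic variant (the function name is mine), reusing the numerically stable formulation from above:

def adaptive_balanced_cross_entropy(y_true, y_pred):
  y_true = tf.cast(y_true, tf.float32)
  # beta = fraction of negative pixels in the current batch, as in [1]
  beta = tf.reduce_mean(1 - y_true)

  weight_a = beta * y_true
  weight_b = (1 - beta) * (1 - y_true)

  o = (tf.math.log1p(tf.exp(-tf.abs(y_pred))) + tf.nn.relu(-y_pred)) * (weight_a + weight_b) + y_pred * weight_b
  return tf.reduce_mean(o)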

Focal loss

Focal loss (FL) [2] tries to down-weight the contribution of easy examples so that the CNN focuses more on hard examples.

FL can be defined as follows:

\[\text{FL}\left(p, \hat{p}\right) = -\left(\alpha (1 - \hat{p})^{\gamma} p \log\left(\hat{p}\right) + (1 - \alpha) \hat{p}^{\gamma} (1-p) \log\left(1 - \hat{p}\right)\right)\]

When \(\gamma = 0\), we obtain BCE.

There are a lot of simplifications possible when implementing FL. TensorFlow uses the same simplifications for sigmoid_cross_entropy_with_logits (see the original code):

\[\begin{aligned} \text{FL}\left(p, \hat{p}\right) &= \alpha(1 - \hat{p})^{\gamma} p \log\left(1 + e^{-x}\right) - \left(1 - \alpha\right)\hat{p}^{\gamma}(1-p) \log\left(\frac{e^{-x}}{1 + e^{-x}}\right)\\ &= \alpha(1 - \hat{p})^{\gamma}p \log\left(1 + e^{-x}\right) - \left(1 - \alpha\right)\hat{p}^{\gamma}\left(1-p\right)\left(-x - \log\left(1 + e^{-x}\right)\right)\\ &= \alpha(1 - \hat{p})^{\gamma}p \log\left(1 + e^{-x}\right) + \left(1 - \alpha\right)\hat{p}^{\gamma}\left(1-p\right)\left(x + \log\left(1 + e^{-x}\right)\right)\\ &= \log\left(1 + e^{-x}\right)\left(\alpha (1 - \hat{p})^{\gamma} p + (1-\alpha)\hat{p}^{\gamma}(1-p)\right) + x(1 - \alpha)\hat{p}^{\gamma}(1 - p)\\ &= \log\left(e^{-x}(1 + e^{x})\right)\left(\alpha (1 - \hat{p})^{\gamma} p + (1-\alpha)\hat{p}^{\gamma}(1-p)\right) + x(1 - \alpha)\hat{p}^{\gamma}(1 - p)\\ &= \left(\log\left(1 + e^{x}\right) - x\right)\left(\alpha (1 - \hat{p})^{\gamma} p + (1-\alpha)\hat{p}^{\gamma}(1-p)\right) + x(1 - \alpha)\hat{p}^{\gamma}(1 - p)\\ &= \left(\log\left(1 + e^{-|x|}\right) + \max(-x, 0)\right)\left(\alpha (1 - \hat{p})^{\gamma} p + (1-\alpha)\hat{p}^{\gamma}(1-p)\right) + x(1 - \alpha)\hat{p}^{\gamma}(1 - p)\\ \end{aligned}\]

And the implementation is then:

def focal_loss(alpha=0.25, gamma=2):
  def focal_loss_with_logits(logits, targets, alpha, gamma, y_pred):
    targets = tf.cast(targets, tf.float32)
    # modulating factors (1 - p_hat)^gamma and p_hat^gamma from the derivation above
    weight_a = alpha * (1 - y_pred) ** gamma * targets
    weight_b = (1 - alpha) * y_pred ** gamma * (1 - targets)
    
    return (tf.math.log1p(tf.exp(-tf.abs(logits))) + tf.nn.relu(-logits)) * (weight_a + weight_b) + logits * weight_b 

  def loss(y_true, logits):
    y_pred = tf.math.sigmoid(logits)
    loss = focal_loss_with_logits(logits=logits, targets=y_true, alpha=alpha, gamma=gamma, y_pred=y_pred)

    return tf.reduce_mean(loss)

  return loss

Distance to the nearest cell

The paper [3] adds to cross entropy a distance function to force the CNN to learn the separation border between touching objects. In other words, this is BCE with an additional distance term:

\[\text{DNC}\left(p, \hat{p}\right) = -\left(w(p) p \log\left(\hat{p}\right) + w(p)(1-p) \log\left(1 - \hat{p}\right)\right)\]

where

\[w(p) = w_c(p) + w_0\cdot\exp\left(-\frac{(d_1(p) + d_2(p))^2}{2\sigma^2}\right)\]

\(d_1(p)\) and \(d_2(p)\) are two functions that calculate the distance to the nearest and second nearest cell, and \(w_c(p) = \beta\) or \(w_c(p) = 1 - \beta\). If we had multiple classes, then \(w_c(p)\) would return a different \(\beta_i\) depending on the class \(i\). The values \(w_0\), \(\sigma\), \(\beta\) are all parameters of the loss function (some constants).

Calculating the exponential term inside the loss function would slow down the training considerably. Hence, it is better to precompute the distance map and pass it to the neural network together with the image input.
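
Given precomputed distance maps, the weight map itself is just the formula above. A minimal sketch (the function name is mine; \(w_0 = 10\) and \(\sigma \approx 5\) are the values suggested in [3]):

import numpy as np

def unet_weight_map(d1, d2, wc, w0=10.0, sigma=5.0):
  # d1, d2: per-pixel distances to the nearest and second nearest cell
  # wc: per-pixel class-balancing weights (beta or 1 - beta in the binary case)
  return wc + w0 * np.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))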

The following code is a variation that calculates the distance only to one object.

from scipy.spatial import distance_matrix
import numpy as np

...

# coordinates of the foreground (object) and background pixels
not_zeros = np.argwhere(img != 0)
zeros = np.argwhere(img == 0)

# pairwise Euclidean distances between background and foreground pixels
dist_matrix = distance_matrix(zeros, not_zeros, p=2)
output = np.zeros((HEIGHT, WIDTH, 1), dtype=np.uint8)

# for each background pixel: distance to the nearest foreground pixel
i = 0
dist = np.min(dist_matrix, axis=1)
for y in range(HEIGHT):
  for x in range(WIDTH):
    if img[y,x] == 0:
      output[y,x] = dist[i]
      i += 1

...
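
If SciPy is available, the same per-pixel distances can be computed in one call. A sketch: distance_transform_edt returns, for every nonzero pixel of its argument (here the background), the Euclidean distance to the nearest zero pixel (here the foreground).

from scipy.ndimage import distance_transform_edt

# background pixels get the distance to the nearest foreground pixel,
# foreground pixels get 0
dist = distance_transform_edt(img == 0)
output = dist[..., np.newaxis].astype(np.uint8)  # same shape and dtype as above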

For example, on the left is a mask and on the right is the corresponding weight map.

(Figure: mask and weight map in comparison)

The blacker the pixel, the higher the weight of the exponential term. To pass the weight matrix as an input, one could use:

from functools import partial
from tensorflow.keras.layers import Input

def loss_function(y_true, y_pred, weights):
...

weight_input = Input(shape=(HEIGHT, WIDTH))
loss = partial(loss_function, weights=weight_input)
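
Note that in TensorFlow 2.x a compiled loss function only receives y_true and y_pred, so another way to use the extra weight input is model.add_loss. A minimal sketch, assuming a hypothetical build_unet function that returns logits:

import tensorflow as tf
from tensorflow.keras.layers import Input

image_input = Input(shape=(HEIGHT, WIDTH, 1))
weight_input = Input(shape=(HEIGHT, WIDTH, 1))
label_input = Input(shape=(HEIGHT, WIDTH, 1))

logits = build_unet(image_input)  # hypothetical model definition

# per-pixel cross entropy, weighted by the precomputed weight map
ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=label_input, logits=logits)
model = tf.keras.Model(inputs=[image_input, weight_input, label_input], outputs=logits)
model.add_loss(tf.reduce_mean(weight_input * ce))

# no loss argument needed; fit is then called with x=[images, weight_maps, masks] and y=None
model.compile(optimizer=optimizer)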

Overlap measures

Dice Loss / F1 score

The Dice coefficient is similar to the Jaccard Index (Intersection over Union, IoU):

\[\text{DC} = \frac{2 TP}{2 TP + FP + FN} = \frac{2|X \cap Y|}{|X| + |Y|}\] \[\text{IoU} = \frac{TP}{TP + FP + FN} = \frac{|X \cap Y|}{|X| + |Y| - |X \cap Y|}\]

where TP are the true positives, FP the false positives and FN the false negatives. Since \(\text{DC} = \frac{2\,\text{IoU}}{1 + \text{IoU}}\), we can see that \(\text{DC} \geq \text{IoU}\).

The Dice coefficient can also be turned into a loss function:

\[\text{DL}\left(p, \hat{p}\right) = 1 - \frac{2\sum p_{h,w}\hat{p}_{h,w}}{\sum p_{h,w} + \sum \hat{p}_{h,w}}\]

where \(p_{h,w} \in \{0,1\}\) and \(0 \leq \hat{p}_{h,w} \leq 1\).

The code is then:

def dice_loss(y_true, y_pred):
  y_true = tf.cast(y_true, tf.float32)
  y_pred = tf.math.sigmoid(y_pred)
  numerator = 2 * tf.reduce_sum(y_true * y_pred)
  denominator = tf.reduce_sum(y_true + y_pred)

  return 1 - numerator / denominator

In general, Dice loss works better when it is applied to whole images (or batches) rather than to single pixels. This means the per-pixel version \(1 - \frac{2p\hat{p}}{p + \hat{p}}\) is never used for segmentation.
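
A sketch of a per-image variant (assuming tensors of shape (batch, height, width, channels)); the sums run over each image separately and the Dice losses are then averaged over the batch:

def dice_loss_per_image(y_true, y_pred):
  y_true = tf.cast(y_true, tf.float32)
  y_pred = tf.math.sigmoid(y_pred)
  # sum over the spatial and channel axes, keep the batch axis
  numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
  denominator = tf.reduce_sum(y_true + y_pred, axis=[1, 2, 3])

  return tf.reduce_mean(1 - numerator / denominator)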

Tversky loss

The Tversky index (TI) is a generalization of the Dice coefficient that adds weights to FP (false positives) and FN (false negatives). The corresponding Tversky loss (TL) is

\[\text{TL}\left(p, \hat{p}\right) = 1 - \frac{p\hat{p}}{p\hat{p} + \beta(1 - p)\hat{p} + (1 - \beta)p(1 - \hat{p})}\]

Let \(\beta = \frac{1}{2}\). Then

\[\begin{aligned} \text{TL}\left(p, \hat{p}\right) &= 1 - \frac{2 p\hat{p}}{2p\hat{p} + (1 - p)\hat{p} + p (1 - \hat{p})}\\ &= 1 - \frac{2 p\hat{p}}{\hat{p} + p} \end{aligned}\]

which is just the regular Dice loss. Since we are interested in sets of pixels, the following function sums over all pixels [5]:

def tversky_loss(beta):
  def loss(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.math.sigmoid(y_pred)
    numerator = y_true * y_pred
    denominator = y_true * y_pred + beta * (1 - y_true) * y_pred + (1 - beta) * y_true * (1 - y_pred)

    return 1 - tf.reduce_sum(numerator) / tf.reduce_sum(denominator)

  return loss

Lovász-Softmax

DL and TL simply relax the hard constraint \(\hat{p} \in \{0,1\}\) in order to obtain a function defined on the domain \([0, 1]\). The paper [6] instead derives a surrogate loss function.

An implementation of Lovász-Softmax can be found on GitHub. Note that this loss does not rely on the sigmoid function (“hinge loss”): it operates directly on the logits, where a negative value means class A and a positive value means class B.

In Keras the loss function can be used as follows:

def lovasz_softmax(y_true, y_pred):
  return lovasz_hinge(labels=y_true, logits=y_pred)

model.compile(loss=lovasz_softmax, optimizer=optimizer, metrics=[pixel_iou])

Combinations

It is also possible to combine multiple loss functions. The following function is quite popular in data competitions:

\[\text{CE}\left(p, \hat{p}\right) + \text{DL}\left(p, \hat{p}\right)\]

Note that \(\text{CE}\) returns a tensor with one value per pixel, while \(\text{DL}\) (as implemented below) returns a single scalar. This way we combine local (\(\text{CE}\)) with global (\(\text{DL}\)) information.

def loss(y_true, y_pred):
  def dice_loss(y_true, y_pred):
    y_pred = tf.math.sigmoid(y_pred)
    numerator = 2 * tf.reduce_sum(y_true * y_pred)
    denominator = tf.reduce_sum(y_true + y_pred)

    return 1 - numerator / denominator

  y_true = tf.cast(y_true, tf.float32)
  o = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred) + dice_loss(y_true, y_pred)
  return tf.reduce_mean(o)

Some people additionally apply the logarithm function to dice_loss.
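
One common reading of this is to take the negative log of the Dice coefficient (i.e. of 1 - dice_loss) instead of using the plain Dice loss. A sketch (this interpretation and the epsilon are my own assumptions), reusing the dice_loss function defined earlier:

def bce_log_dice_loss(y_true, y_pred):
  y_true = tf.cast(y_true, tf.float32)
  ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)
  # 1 - dice_loss(...) is the soft Dice coefficient; the epsilon avoids log(0)
  log_dice = -tf.math.log(1.0 - dice_loss(y_true, y_pred) + 1e-7)
  return tf.reduce_mean(ce) + log_dice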

Example: Let \(\mathbf{P}\) be our ground truth mask, \(\mathbf{\hat{P}}\) the prediction and \(\mathbf{L}\) the per-pixel result of the loss function (before taking the mean).

\[\mathbf{P} = \begin{bmatrix}1 & 1\\0 & 0\end{bmatrix}\] \[\mathbf{\hat{P}} = \begin{bmatrix}0.5 & 0.6\\0.2 & 0.1\end{bmatrix}\]

Then \(\mathbf{L} = \begin{bmatrix}-1\log(0.5) + l_2 & -1\log(0.6) + l_2\\-(1 - 0)\log(1 - 0.2) + l_2 & -(1 - 0)\log(1 - 0.1) + l_2\end{bmatrix}\), where

\[l_2 = 1 - \frac{2(1 \cdot 0.5 + 1 \cdot 0.6 + 0 \cdot 0.2 + 0 \cdot 0.1)}{(1 + 1 + 0 + 0) + (0.5 + 0.6 + 0.2 + 0.1)} \approx 0.3529\]

The result is:

\[\mathbf{L} \approx \begin{bmatrix}0.6931 + 0.3529 & 0.5108 + 0.3529\\0.2231 + 0.3529 & 0.1054 + 0.3529\end{bmatrix} = \begin{bmatrix}1.046 & 0.8637\\0.576 & 0.4583\end{bmatrix}\]

Next, we compute the mean via tf.reduce_mean which results in \(\frac{1}{4}(1.046 + 0.8637 + 0.576 + 0.4583) = 0.736\)

Let’s check the result:

c = tf.constant([[1.0, 1.0], [0.0, 0.0]])  # ground truth P
d = tf.constant([[0.5, 0.6], [0.2, 0.1]])  # predicted probabilities P hat

# log(d / (1 - d)) converts the probabilities back to logits
print(loss(c, tf.math.log(d / (1 - d))))
# tf.Tensor(0.7360604, shape=(), dtype=float32)

References

[1] S. Xie and Z. Tu. Holistically-Nested Edge Detection, 2015.

[2] T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar. Focal Loss for Dense Object Detection, 2017.

[3] O. Ronneberger, P. Fischer, and T. Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation, 2015.

[4] F. Milletari, N. Navab, and S.-A. Ahmadi. V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation, 2016.

[5] S. S. M. Salehi, D. Erdogmus, and A. Gholipour. Tversky loss function for image segmentation using 3D fully convolutional deep networks, 2017.

[6] M. Berman, A. R. Triki, and M. B. Blaschko. The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks, 2018.