Loss function for Image Segmentation Tasks!
Nov. 28, 2023, 7:57 p.m.
Imagine you are training a computer to understand pictures, like teaching it to recognize objects in an image. Now, the "loss function" is like a guide telling the computer how it's doing. It helps the computer learn and get better at its job.
In our journey today, we're zooming into a specific area called "semantic segmentation." This is like teaching the computer to not just see the whole picture but to understand each tiny dot (pixel) and say what it is.
Now, the cool part is we've got seven different guides (loss functions) that help the computer learn better. Some focus on how things are spread out in the picture, some on specific regions, and some are a mix of both. It's like having different tools in a toolbox to build a smart computer vision system.
Here is the list of topics covered in this blog post:
- Cross Entropy Loss
- Focal Loss
- Dice Loss
- Tversky Loss
- Combined loss functions (Focal Tversky, Combo, Hybrid Focal)
Ready to explore these loss functions? Let's go!
CROSS ENTROPY LOSS
Cross entropy loss is one of the most widely used loss functions in deep learning tasks. The basic concept of cross entropy loss is to measure the difference between two probability distributions. For binary classification (in the segmentation domain, this is when we have only the background and one object class to segment), the formulation, in this case the Binary Cross Entropy Loss, is as follows:
BCEloss(y, ŷ) = - (y * log(ŷ) + (1 - y) * log(1 - ŷ))
- y: The true binary label (0 or 1).
- ŷ: The predicted probability corresponding to class 1 (from 0 to 1).
- log(ŷ): The logarithm of the predicted probability.
- log(1 - ŷ): The logarithm of the complement of the predicted probability.
We can extend the BCE loss to the multi-class case by treating the labels as one-hot encoded vectors, and the formulation in this case is as follows (a small worked example is given after the symbol definitions):
MCEloss(y, ŷ) = − ∑ᵢ₌₁ᶜ yᵢ · log(ŷᵢ)
- y: One-hot encoded true class labels (a vector of length C where C is the number of classes).
- ŷ: Predicted class probabilities (a vector of length C representing the model's confidence scores for each class).
- log(ŷᵢ): The logarithm of the predicted probability of class i.
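To make the formula concrete, here is a tiny worked example for a single pixel with C = 3 classes; the probability values are made up for illustration. Because y is one-hot, the sum collapses to the negative log-probability of the true class.
import torch

# A single pixel with C = 3 classes (values are made up for illustration)
y = torch.tensor([0., 1., 0.])         # one-hot true label: the pixel belongs to class 1
y_hat = torch.tensor([0.2, 0.7, 0.1])  # predicted class probabilities (already softmax-ed)
mce = -(y * torch.log(y_hat)).sum()
print(mce)  # tensor(0.3567), i.e. -log(0.7), the negative log-probability of the true class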
The goal of cross-entropy loss is to reduce errors on a pixel-by-pixel basis. However, in situations where there is an imbalance in the distribution of classes, this method tends to emphasize larger objects, neglecting smaller ones. This imbalance can lead to suboptimal segmentation quality for smaller objects.
For a clearer comprehension, let's delve into an implementation example using the PyTorch library, one of the most popular frameworks in the Deep Learning field:
import torch

# The target is a 10x10 mask with binary classification, where 1 is the object to segment and 0 is the background
'''
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
0 0 0 1 1 1 1 0 0 0
0 0 1 1 1 1 1 1 0 0
0 0 1 1 1 1 1 1 0 0
0 0 1 1 1 1 1 1 0 0
0 0 0 1 1 1 1 0 0 0
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0
'''
# Create the example tensors
H = W = 10
ridx = [2]*4 + [3]*6 + [4]*6 + [5]*6 + [6]*4  # target row indices where the value is 1
cidx = [j for j in range(3, 7)] + [j for j in range(2, 8)]*3 + [j for j in range(3, 7)]  # target column indices where the value is 1
predictions = torch.randn((H, W)) * 10  # samples from a standard normal distribution, scaled by 10, used as prediction logits
target = torch.zeros_like(predictions)
target[ridx,cidx] = 1
# Calculating the Binary Cross Entropy loss
# First we convert the logits to probabilities using the sigmoid function
prediction_probability = torch.sigmoid(predictions)
# The logits are now mapped to values between 0 and 1, indicating the probability that each pixel belongs to class 1
# We create the bce_loss function that takes as input the predicted probabilities and the target (the true labels)
def bce_loss(p_prob, target):
    return -(target * torch.log(p_prob) + (1 - target) * torch.log(1 - p_prob))
# We execute the function to see what is the result
bce_loss(prediction_probability,target)
# We obtain a matrix with the log loss for each pixel. In practice the probabilities are clamped to avoid infinite values where p_prob is exactly 0 or 1.
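As a sanity check, the same per-pixel values can be obtained with PyTorch's built-in binary_cross_entropy_with_logits, which works directly on the logits and handles numerical stability internally:
import torch.nn.functional as F

# Built-in BCE on the raw logits; reduction='none' keeps the per-pixel loss matrix
builtin_loss = F.binary_cross_entropy_with_logits(predictions, target, reduction='none')
# Should match our manual bce_loss except at pixels where sigmoid saturates to exactly 0 or 1
# (there the manual version returns inf, while the built-in version stays finite)
print(builtin_loss)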
FOCAL LOSS
The Focal Loss plays a crucial role in mitigating class imbalance during the training phase, particularly in tasks like pixel-wise classification. This specialized loss function introduces a modulating term to the standard cross entropy loss, strategically directing the learning process towards challenging misclassifications inherent in semantic segmentation.
Essentially, the Focal Loss dynamically scales the cross entropy loss, with the scaling factor diminishing as the model gains confidence in accurately assigning semantic labels. This adaptive scaling factor automatically de-prioritizes the impact of easily classified pixels during training, swiftly guiding the model's attention towards intricacies in segmentation, such as accurately delineating boundaries and handling ambiguous regions. In practice, the higher the gamma coefficient, the more sharply the loss value drops for pixels predicted with high probability.
The formulation of the Focal Loss for a multi-class classification problem is the following (a short PyTorch sketch is given after the symbol definitions):
Floss(y, ŷ) = α · (1 − ŷ)ᵞ · MCEloss(y, ŷ)
- y: One-hot encoded true class labels (a vector of length C where C is the number of classes).
- ŷ: Predicted class probabilities (a vector of length C representing the model's confidence scores for each class).
- α: class weights (a vector of length C representing the weight of each class).
- γ: the modulating (focusing) term gamma, used as the exponent.
- MCEloss: Standard Multi-class Cross Entropy loss.
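To ground the formula, below is a minimal PyTorch sketch of a pixel-wise multi-class focal loss; the class count, alpha weights, and gamma value are illustrative assumptions rather than recommended settings.
import torch
import torch.nn.functional as F

def focal_loss(logits, target, alpha, gamma=2.0):
    '''logits: (N, C, H, W) raw scores, target: (N, H, W) class indices,
    alpha: (C,) per-class weights, gamma: focusing parameter.'''
    # Per-pixel cross entropy and the probability the model assigns to the true class
    ce = F.cross_entropy(logits, target, reduction='none')  # shape (N, H, W)
    p_t = torch.exp(-ce)                                    # softmax probability of the true class
    alpha_t = alpha[target]                                 # per-pixel class weight
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()

# Toy usage: batch of 2 images, 3 classes, 10x10 pixels (all values are illustrative)
logits = torch.randn(2, 3, 10, 10)
labels = torch.randint(0, 3, (2, 10, 10))
alpha = torch.tensor([0.25, 0.5, 1.0])  # made-up class weights
print(focal_loss(logits, labels, alpha))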
DICE LOSS
The Dice Loss is based on the Dice similarity coefficient (DSC), a metric used to quantify the similarity between two sets. It is commonly employed in the field of medical image analysis.
The Dice coefficient yields values in the range [0, 1], where 0 indicates no overlap between the sets, and 1 represents perfect overlap or complete similarity.
In the context of image segmentation, the sets A and B often represent the binary masks of the predicted and ground truth segmentation maps. A higher Dice coefficient suggests a better agreement between the predicted and actual segmentation.
The Dice Similarity Coefficient is particularly favored when dealing with imbalanced datasets, as it is less sensitive to class imbalance than some other metrics. The formulation of the DSC is as follows:
DSC(y, ŷ) = 2·TP / (2·TP + FP + FN)
- True Positive (TP): The number of pixels that are correctly classified as positive by the model.
- False Positive (FP): The number of pixels that are predicted as positive but are actually negative.
- False Negative (FN): The number of pixels that are predicted as negative but are actually positive.
The formulation for the loss is as follows (a minimal PyTorch sketch is given after the symbol definitions):
Dloss(y, ŷ) = 1 - DSC(y, ŷ)
- y: One-hot encoded true class labels (a vector of length C where C is the number of classes).
- ŷ: Predicted class probabilities (a vector of length C representing the model's confidence scores for each class).
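Here is a minimal sketch of a soft (differentiable) binary Dice loss, reusing the predictions and target tensors from the cross-entropy example above; the smooth term is a common convention, assumed here, to avoid division by zero.
import torch

def dice_loss(p_prob, target, smooth=1.0):
    '''p_prob: predicted probabilities in [0, 1], target: binary ground-truth mask.'''
    # Soft counterparts of TP, FP and FN, summed over all pixels
    tp = (p_prob * target).sum()
    fp = (p_prob * (1 - target)).sum()
    fn = ((1 - p_prob) * target).sum()
    dsc = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    return 1 - dsc

# Reusing the tensors from the BCE example above
print(dice_loss(torch.sigmoid(predictions), target))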
TVERSKY LOSS
The Tversky loss is an extension of the Tversky index, a statistical measure of the similarity between two sets. In the context of image segmentation, the Tversky loss is utilized as a cost function during the training of neural networks.
The formula for the Tversky index, and consequently the Tversky loss, is as follows:
TVindex(y, ŷ) = TP / (TP + α·FP + β·FN)
Here, α and β are hyperparameters that control the weight assigned to false positives and false negatives, respectively. When α and β are both equal to 1, the Tversky index is the same as the Jaccard index (IoU); when α and β are both equal to 0.5, we obtain the DSC.
The formulation for the loss is as follows:
TVloss(y, ŷ) = 1 - TVindex(y, ŷ)
The Tversky loss is particularly useful in scenarios where there is a class imbalance, meaning that one class significantly outnumbers the others. By introducing the α and β parameters, the loss function allows for a fine-tuning of the emphasis on false positives and false negatives. This adaptability makes the Tversky loss effective in tasks such as medical image segmentation, where accurate delineation of specific structures is crucial.
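A sketch of the corresponding loss, built on the same soft TP/FP/FN counts used for the Dice loss above; the α = 0.3, β = 0.7 defaults below are just an illustrative choice that penalizes false negatives more heavily.
import torch

def tversky_loss(p_prob, target, alpha=0.3, beta=0.7, smooth=1.0):
    '''p_prob: predicted probabilities in [0, 1], target: binary ground-truth mask.
    alpha weights false positives, beta weights false negatives.'''
    tp = (p_prob * target).sum()
    fp = (p_prob * (1 - target)).sum()
    fn = ((1 - p_prob) * target).sum()
    tv_index = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    return 1 - tv_index

# With alpha = beta = 0.5 this matches the Dice formulation above (up to the smoothing term)
print(tversky_loss(torch.sigmoid(predictions), target))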
Numerous research papers have delved into the pursuit of identifying the optimal loss function for image segmentation tasks. Interestingly, many of these proposed loss functions are intricate combinations of previously described methodologies. Below, you'll find links to papers that explore particularly compelling approaches to loss functions in the context of image segmentation:
- Focal Tversky Loss (Abraham & Khan, 2019): A focal loss function based on the Tversky index to address the issue of data imbalance in medical image segmentation (a brief sketch of this combination is given after this list).
- Combo Loss (Taghanaki et al., 2019): The Combo loss is an amalgamation of the Dice loss, Binary Cross-Entropy (BCE) loss, and the Focal loss. Each component plays a distinct role in refining the model's ability to handle both large and small objects within an image.
- Hybrid Focal loss (Yeung et al., 2021): Integrates tunable parameters for output imbalance and focal parameters for input imbalance in both the Dice and cross-entropy components. The Hybrid Focal loss (L_HF) is defined by substituting the Dice loss with the Focal Tversky loss and the cross-entropy loss with the Focal loss.
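To illustrate how such combinations are assembled, here is a minimal sketch of a Focal-Tversky-style loss that simply raises the Tversky loss defined above to a power gamma; the exponent value is illustrative, and the exact parameterization varies between papers.
def focal_tversky_loss(p_prob, target, alpha=0.3, beta=0.7, gamma=0.75):
    # Raise the Tversky loss to the power gamma to modulate the contribution of easy vs. hard examples
    return tversky_loss(p_prob, target, alpha, beta) ** gamma

print(focal_tversky_loss(torch.sigmoid(predictions), target))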