Debug Models By Plotting The Top Loss Images
Introduction
In this report, we'll show you how to find the images in your training dataset that make the model incur disproportionately high losses. Investigating the data points that incur high losses is an essential step in debugging deep learning models.
Acknowledgement: This technique is heavily inspired by the fast.ai library; we picked it up from Jeremy Howard's lectures in the Practical Deep Learning for Coders course.
Check out the code on GitHub →
Problem Motivation
To motivate the problem, let's take the canonical MNIST training examples and a shallow CNN model. When we fit this model to the training images, we observe the following:
- Look at the fourth image in the plot below: the model predicts it as 7, while the original class of the image is 1. Even to human eyes, the digit almost certainly looks like a 1.
- The eighth image is a clear sign of label noise, as it looks like an 8 but is labeled as 5.
It's essential to note that the model made these predictions with 100% certainty. As you may have noticed, the prediction statistics are collected from the last epoch (you can slide across different epochs with the little slider at the top left). Since these predictions come from the end of training, we can at least say that the model is not making these mistakes simply because it hasn't learned anything yet.
So, What Are We Doing Here?
As we train the model, we plot the images that cause it to incur high losses, along with some useful information about each one. With this utility, we can immediately spot two issues in the training dataset:
- Some images are confusing to recognize even for human eyes, which makes model training inefficient.
- There is label noise in the training dataset, which is one reason many models end up with poor performance.
We also make the following observations:
- The structure of images often varies across classes, but if a structure is unique and not widely represented in the dataset, the model can get confused by it (for example, the sixth image in the above plot).
Deep learning thrives on vast quantities of (quality) data, so it can become intractable to manually inspect every data point in a dataset for the problems mentioned above. This method is useful for such cases too.
(Some might argue that to reduce issues like memorization in deep neural networks, it is often useful to regularize networks with noisy labels, but we will leave that discussion for some other time.)
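To make this concrete, here's a minimal sketch of how the top-loss images can be surfaced. This is not the exact code from the repository linked above; it assumes a small Keras CNN trained briefly on MNIST, and the architecture and hyperparameters are illustrative placeholders.

```python
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Load MNIST and scale pixel values to [0, 1].
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., None].astype("float32") / 255.0

# A shallow CNN, roughly in the spirit of the one used in this report.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, activation="relu", input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(x_train, y_train, epochs=1, batch_size=128)

# Compute the per-example loss over the whole training set (no reduction).
probs = model.predict(x_train, batch_size=512)
losses = tf.keras.losses.sparse_categorical_crossentropy(y_train, probs).numpy()

# Plot the nine highest-loss images with their label, prediction, and loss.
top_idx = np.argsort(losses)[::-1][:9]
plt.figure(figsize=(6, 6))
for i, idx in enumerate(top_idx):
    plt.subplot(3, 3, i + 1)
    plt.imshow(x_train[idx].squeeze(), cmap="gray")
    plt.title(f"label={y_train[idx]}  pred={probs[idx].argmax()}  "
              f"loss={losses[idx]:.2f}", fontsize=8)
    plt.axis("off")
plt.tight_layout()
plt.show()
```

Sorting the per-example losses and eyeballing the top of the list is all the technique really is; the titles on the plotted images make label noise and confusing digits easy to spot.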
What can we do to mitigate some of these problems?
We can't cover every little trick practitioners use to deal with these issues. Still, in our experience, the following structural measures can help reduce some of the problems mentioned above.
Label Smoothing
If you look below, the model is predicting all the images whose actual label is 7 as 1. The obvious thing to do is assign the label 7 with full confidence; we call this hard label assignment. The corresponding one-hot encoded vector looks like this: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]. But even if we are 100% sure the digit is a 7, we can instead give 90 percent of the probability mass to class 7 and divide the remaining 10 percent across all the classes. The smoothed vector then looks like this: [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.91, 0.01, 0.01]. This simple technique prevents the model from getting too confident about its predictions.
Acknowledgment: This explanation is adapted from this blog post by PyImageSearch.
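Here's a minimal sketch of the idea. The `smooth_labels` helper is our own illustration of the formula above; Keras can also apply smoothing directly inside the loss:

```python
import tensorflow as tf

def smooth_labels(y_onehot, eps=0.1):
    # Keep (1 - eps) of the mass on the target and spread eps uniformly.
    num_classes = y_onehot.shape[-1]
    return y_onehot * (1.0 - eps) + eps / num_classes

y = tf.keras.utils.to_categorical([7], num_classes=10)
print(smooth_labels(y))  # -> 0.91 at index 7, 0.01 everywhere else

# Keras can do the same thing inside the loss itself. Note that
# CategoricalCrossentropy expects one-hot targets, not integer labels.
loss_fn = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
```

With a smoothing factor of 0.1 over 10 classes, the target class keeps 0.91 of the probability mass and every other class gets 0.01, exactly as in the vector above.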
Label Correction
Label correction helps the model train with fewer confusions (fewer, not zero, because there are several other ways a deep learning model can get confused). Following this approach, we first correct a few noisy labels (as shown below) in the training dataset and retrain the model to see if the confusion improves. Although this method might not scale to very large datasets, experiments have shown that even a small amount of label correction can improve a model's performance.
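As a sketch, label correction can be as simple as overwriting a few entries of the label array. The indices and corrected labels below are purely hypothetical, and the snippet reuses `x_train`, `y_train`, and `model` from the earlier top-loss example:

```python
# Suppose the top-loss plot revealed a few mislabeled training examples.
# In practice you would read the indices off the plotted images; these
# particular values are made up for illustration.
corrections = {123: 8, 456: 1}  # training index -> corrected label

y_train_fixed = y_train.copy()
for idx, label in corrections.items():
    y_train_fixed[idx] = label

# Retrain on the corrected labels and check whether the confusion improves.
model.fit(x_train, y_train_fixed, epochs=5, batch_size=128)
```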
Discarding Confusing Data Points
Discarding confusing data points also helps the model train better. Following this approach, we drop highly ambiguous data points from the training dataset and retrain the model to see if performance improves. This method is generally applied when the dataset is sufficiently large. In the next section, we show whether label smoothing can help clear up some of the problems our model is having.
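Before moving on, here's a minimal sketch of discarding the highest-loss examples. It reuses the `losses` array, data, and `model` from the first snippet; the cutoff `k` is an assumption you would tune to your dataset:

```python
import numpy as np

# Drop the k highest-loss training examples and retrain on the rest.
k = 100  # illustrative cutoff, not a recommendation
drop_idx = np.argsort(losses)[::-1][:k]
keep = np.ones(len(x_train), dtype=bool)
keep[drop_idx] = False

model.fit(x_train[keep], y_train[keep], epochs=5, batch_size=128)
```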
Does Label Smoothing Improve Anything?
Label Smoothing Factor Of 0.2 Helped The Model
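If you want to try this yourself, the factor can be set the same way as in the earlier sketch; again, this loss assumes one-hot encoded targets:

```python
import tensorflow as tf

# A smoothing factor of 0.2, matching the run that helped the model here.
loss_fn = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.2)
```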
Conclusion
We hope this simple yet productive debugging technique finds its way into your deep learning practitioner's toolkit. Let us know if you have any feedback to share via GitHub issues. You can find the code at the link below and reproduce the results.