Debug Models By Plotting The Top Loss Images

Debug and improve your models by finding the images in the training dataset that incur disproportionally high losses.
Tulasi Ram Laghumavarapu


In this report, we'll show you how to find the images in your training dataset that make the model incur disproportionally high losses. Investigating datapoints that incur high losses is an essential step in debugging our deep learning models.

Acknowledgement: This technique is heavily motivated by the library, and we picked it up from Jeremy Howard's lectures presented in the Practical Deep Learning for Coders course.

Check out the code on GitHub →

Problem Motivation

Let us take the canonical MNIST training examples and a shallow CNN model to motivate the problem. When we try to fit this model to the training images, we observe the following:

It's essential to note that the model made these predictions with 100% certainty. As you may have noticed, the prediction statistics are collected from the last epoch (you can move across epochs with the little slider at the top-left). Because these statistics come from the end of training, we can at least say that the model is not incurring these losses because it has not learned anything.


So, What Are We Doing Here?

As we train the model, we plot the images that cause the model to incur high losses, along with some useful information about each one (such as its label, the predicted class, and the loss). With this utility, we immediately spot two issues in the training dataset.

We also make the following observations:

Deep learning thrives on vast quantities of (quality) data, so it can become intractable to manually inspect every data point in a dataset for the problems mentioned above. This method is useful in such cases too, because it points us to the examples most worth inspecting.

(Some might argue that to reduce issues like memorization in deep neural networks, it is often useful to regularize networks with noisy labels, but we will leave that discussion for some other time.)
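To make the ranking step in this section concrete, here is a minimal NumPy sketch. It assumes we already have the model's predicted class probabilities for each training image; the function names `per_example_cross_entropy` and `top_losses` are illustrative and are not taken from the accompanying code.

```python
import numpy as np

def per_example_cross_entropy(probs, labels, eps=1e-12):
    """Cross-entropy loss of each example, given predicted class probabilities."""
    probs = np.asarray(probs, dtype=np.float64)
    labels = np.asarray(labels)
    # Probability the model assigned to the labelled class of each example.
    p_true = probs[np.arange(len(labels)), labels]
    # Clip to avoid log(0) when the model is fully confident in a wrong class.
    return -np.log(np.clip(p_true, eps, 1.0))

def top_losses(probs, labels, k=9):
    """Describe the k examples with the highest loss, worst first."""
    probs = np.asarray(probs, dtype=np.float64)
    labels = np.asarray(labels)
    losses = per_example_cross_entropy(probs, labels)
    order = np.argsort(losses)[::-1][:k]
    return [
        {
            "index": int(i),
            "loss": float(losses[i]),
            "label": int(labels[i]),
            "predicted": int(probs[i].argmax()),
            "confidence": float(probs[i].max()),
        }
        for i in order
    ]

# Toy data (hypothetical): three examples, three classes.
toy_probs = np.array([
    [0.90, 0.05, 0.05],  # labelled 0, predicted 0 confidently
    [0.01, 0.98, 0.01],  # labelled 0, predicted 1 confidently
    [0.30, 0.30, 0.40],  # labelled 2, predicted 2 with low confidence
])
toy_labels = np.array([0, 0, 2])
worst = top_losses(toy_probs, toy_labels, k=2)
```

Each returned entry carries the information we would print next to the image: its index in the dataset, the loss, the label, the predicted class, and the model's confidence. Confidently wrong predictions rise to the top, which is exactly where mislabelled or ambiguous images tend to show up.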

What can we do to mitigate some of these problems?

We may not be able to cover every small detail of how practitioners deal with these issues. Still, in our experience the following measures can help reduce some of the problems mentioned above.

Label Smoothing

As you can see below, the model predicts images whose actual label is 7 as 1. One obvious thing we can do is assign the label 7 outright; we call this hard label assignment. The corresponding one-hot encoded vector would look like this: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]. Even though I am 100% sure it is a 7, I would instead give 90 percent to class 7 and divide the remaining 10 percent equally among all ten classes. The resulting vector would look like this: [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.91, 0.01, 0.01]. This simple technique prevents the model from becoming too confident about its predictions.

Acknowledgment: This explanation is referred from this blog post by PyImageSearch.
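The smoothing arithmetic described above can be sketched in a few lines of NumPy. The helper name `smooth_labels` is hypothetical; it spreads the smoothing factor evenly over all classes, including the true one, which is the convention the worked example follows.

```python
import numpy as np

def smooth_labels(labels, num_classes, factor=0.1):
    """Turn integer labels into label-smoothed target vectors."""
    # Hard (one-hot) targets: 1.0 at the labelled class, 0.0 elsewhere.
    one_hot = np.eye(num_classes)[np.asarray(labels)]
    # Keep (1 - factor) on the true class and spread factor over all classes.
    return one_hot * (1.0 - factor) + factor / num_classes

smoothed = smooth_labels([7], num_classes=10, factor=0.1)
```

With a factor of 0.1 and ten classes, class 7 receives 0.91 and every other class receives 0.01; a factor of 0.2 would give 0.82 and 0.02. In practice you rarely need to do this by hand: for example, `tf.keras.losses.CategoricalCrossentropy` accepts a `label_smoothing` argument that applies the same formula inside the loss.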

Label Correction

Label correction helps the model train with fewer confusions (fewer rather than none, because there are several other ways a deep learning model can get confused). Following this approach, we first correct a few noisy labels (as shown below) in the training dataset and then retrain our model to see if the confusion is reduced. Although manual correction may not scale to very large datasets, even a small amount of label correction can improve a model's performance.
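As a rough sketch of the correction step, assume we have inspected the top-loss images and written down the right label for a handful of them. The helper `apply_label_corrections` and the example indices below are hypothetical.

```python
import numpy as np

def apply_label_corrections(labels, corrections):
    """Return a copy of labels with the given {index: new_label} fixes applied."""
    corrected = np.array(labels, copy=True)  # leave the original labels untouched
    for index, new_label in corrections.items():
        corrected[index] = new_label
    return corrected

# Hypothetical example: examples 1 and 3 were labelled 1 but are really 7s.
noisy_labels = np.array([7, 1, 3, 1])
fixed_labels = apply_label_corrections(noisy_labels, {1: 7, 3: 7})
```

Keeping the corrections in a small dictionary, rather than editing the label array in place, makes it easy to review the changes and to retrain with and without them.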

Discarding Confusing Data Points

Discarding the confusing data points also helps the model train better. Following this approach, we drop the most confusing data points from the training dataset and retrain our model to see if the performance improves. This method is generally applied when the dataset is sufficiently large. In the next section, we examine whether label smoothing can help clear up some of the problems our model is having.
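The discarding step described above could be sketched as follows, assuming we already have one loss value per training example. The helper `drop_highest_loss` is a hypothetical name, and dropping a fixed number of examples is only one possible policy; in practice you would inspect the images before removing them.

```python
import numpy as np

def drop_highest_loss(images, labels, losses, num_to_drop):
    """Remove the num_to_drop examples with the highest loss."""
    images, labels, losses = map(np.asarray, (images, labels, losses))
    # Indices sorted by ascending loss; keep all but the last num_to_drop.
    keep = np.argsort(losses)[: len(losses) - num_to_drop]
    keep = np.sort(keep)  # preserve the original dataset order
    return images[keep], labels[keep]

# Toy data (hypothetical): five one-pixel "images" with their per-example losses.
toy_images = np.arange(5).reshape(5, 1)
toy_labels = np.array([0, 1, 2, 3, 4])
toy_losses = np.array([0.10, 3.00, 0.20, 2.50, 0.05])
kept_images, kept_labels = drop_highest_loss(toy_images, toy_labels, toy_losses, 2)
```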


Does Label Smoothing Improve Anything?


Label Smoothing Factor Of 0.2 Helped The Model



We hope this simple yet productive debugging technique will find its way into your deep learning practitioner toolkit. Let us know if you have any feedback to share via GitHub issues. You can find the code and reproduce the results by following the link below.

Check out the code on GitHub →