Plotting top loss images while training models
Introduction
This report is authored by Tulasi Ram Laghumavarapu and Sayak Paul.
Data is the better half of machine learning models. This is why it is important to make sure that our models are exposed to quality data while they are training. Real data is noisy, unclean, and often confusing. In our own lives, too, lessons that are too confusing to grasp make it extremely difficult for us to actually learn the topic at hand. The same applies to machine learning models.
In this report, we are going to show you how to investigate the data points from a training set that make a model incur very high losses.
Acknowledgement: This technique is heavily motivated by the fast.ai library and we picked it up from Jeremy Howard's lectures presented in the Practical Deep Learning for Coders course.
Check out the code on GitHub →
Problem motivation
So, what are we doing here?
The idea is extremely simple. As our model trains, we simply plot the images that cause it to incur high losses, along with some useful information about them. With this utility, as seen above, we could immediately spot two issues in the training dataset:
- Some images are confusing to recognize even for human eyes, which makes the model training process a bit inefficient.
- There is definitely label noise in the training dataset, which is also one of the reasons many models perform miserably.
Apart from these, with this technique, we can also start to get a sense of the following:
- Structures of images belonging to different classes often vary, but if a structure is really unique and there are not enough examples of it for the model to learn from, the model can get confused about it (for example, the sixth image in the above plot).
- Deep learning shines on large volumes of (quality) data, and it can become intractable to manually investigate each data point in our datasets to catch issues like the above. In these situations too, this technique can be really helpful.
- When fine-tuning on different datasets with a pre-trained network as a starting point, the pre-trained weights can be damaged very quickly by corrupted data points. In those cases, this little debugging technique can be helpful.
(Some might argue that to reduce issues like memorization in deep neural networks, it is often useful to regularize them with noisy labels but we will leave that discussion for some other time.)
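The core of the idea described in this section can be sketched in a few lines. The snippet below is a minimal, framework-agnostic illustration in plain Python, not the code from the linked repository: it assumes we already have the model's predicted class probabilities for each training image, and the function names `per_sample_cross_entropy` and `top_losses` are hypothetical helpers introduced here for illustration.

```python
import math
from typing import List, Sequence, Tuple


def per_sample_cross_entropy(
    probs: Sequence[Sequence[float]],
    labels: Sequence[int],
    eps: float = 1e-12,
) -> List[float]:
    """Cross-entropy of every sample, kept separate instead of averaged over the batch."""
    return [-math.log(max(p[y], eps)) for p, y in zip(probs, labels)]


def top_losses(
    probs: Sequence[Sequence[float]],
    labels: Sequence[int],
    k: int = 3,
) -> List[Tuple[int, float, int, int]]:
    """Return (index, loss, predicted class, given label) for the k highest-loss samples."""
    losses = per_sample_cross_entropy(probs, labels)
    # Sort sample indices by loss, highest first, and keep the first k.
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)[:k]
    results = []
    for i in order:
        predicted = max(range(len(probs[i])), key=probs[i].__getitem__)
        results.append((i, losses[i], predicted, labels[i]))
    return results


# Toy example (made-up numbers): four samples, two classes.
example_probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.05, 0.95]]
example_labels = [0, 0, 1, 1]

for index, loss, predicted, label in top_losses(example_probs, example_labels, k=2):
    print(f"sample {index}: loss={loss:.3f}, predicted={predicted}, label={label}")
```

In a real training loop, the returned indices would be used to look up the corresponding images and plot them with the loss, the predicted class, and the given label as the title, for instance at the end of each epoch.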
What can we do to mitigate some of these problems?
Well, we cannot possibly cover all the nifty little things that deep learning practitioners do to tackle these problems, but in our experience, the following techniques can be extremely useful for mitigating some of the above-mentioned problems:
- Label smoothing:
- Label correction:
- Discarding confusing data points:
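Two of the techniques listed above lend themselves to a short sketch. The snippet below is an illustration under stated assumptions rather than the authors' implementation: it uses the common formulation of label smoothing, in which a one-hot target is mixed with a uniform distribution over the classes, and a simple loss threshold for discarding data points. The function names, the smoothing factor `alpha`, and the threshold value are all hypothetical choices made for this example.

```python
from typing import List, Sequence


def smooth_labels(one_hot: Sequence[float], alpha: float = 0.1) -> List[float]:
    """Mix a one-hot target with a uniform distribution over the classes."""
    num_classes = len(one_hot)
    return [v * (1.0 - alpha) + alpha / num_classes for v in one_hot]


def keep_indices(losses: Sequence[float], threshold: float) -> List[int]:
    """Indices of samples whose loss does not exceed the (hypothetical) threshold."""
    return [i for i, loss in enumerate(losses) if loss <= threshold]


# Toy example (made-up numbers).
smoothed = smooth_labels([0.0, 0.0, 1.0], alpha=0.1)
print("smoothed target:", [round(v, 4) for v in smoothed])

kept = keep_indices([0.1, 2.5, 0.4], threshold=1.0)
print("samples kept:", kept)
```

A fixed threshold is only one possible criterion. In practice it is worth looking at the high-loss samples before dropping them, since a high loss can also indicate a rare but valid example rather than a corrupted one.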