Interpretability in Deep Learning With W&B - CAM and GradCAM

This report reviews how Grad-CAM counters the common criticism that neural networks are not interpretable. Made by Ayush Thakur using Weights & Biases

Introduction

Training a classification model is interesting, but have you ever wondered how your model is actually making its predictions? Is your model really looking at the dog in the image before classifying it as a dog with 98% confidence? Interesting, isn't it? In today's report, we will explore why deep learning models need to be interpretable, along with some interesting methods to peek under the hood of a deep learning model. Deep learning interpretability is a very exciting area of research, and much progress is already being made in this direction.
So why should you care about interpretability? After all, the success of your business or your project is judged primarily by how good your model's accuracy is. But in order to deploy our models in the real world, we need to consider other factors too. For instance, is the model racially biased? Or, what if it classifies humans with 97% accuracy overall, but men with 99% accuracy and women with only 95%?
Understanding how a model makes its predictions can also help us debug our network. [Check out this blog post on 'Debugging Neural Networks with PyTorch and W&B Using Gradients and Visualizations' for some other techniques that can help.]
At this point, we are all familiar with the idea that deep learning models make predictions based on learned representations expressed in terms of other, simpler representations. That is, deep learning allows us to build complex concepts out of simpler ones. Here's an amazing Distill Pub post to help you understand this concept better. We also know that these representations are learned while we train the model on our input data and labels, in the case of a supervised learning task like image classification. One of the criticisms of this approach is that the learned features in a neural network are not interpretable.
Today we'll look at three techniques that address this criticism and shed light on the "black-box" nature of neural networks.

Code →

Visualizing Activations

Neural networks trained to solve image classification problems typically have at least one convolutional layer. We will focus our attention on image classification and try to build an intuition for what the "features" referred to in the deep learning literature actually are.
The most straightforward way to make some sense of what's going on inside a neural network is to visualize the activations of the intermediate layers during the forward pass. For this report, I built a simple cat-vs-dog image classifier. We are interested in visualizing the output of the first convolutional layer. To do so, I built a simple FeatureLogger Keras callback with W&B integration.

Implement Feature Logger
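Here is a minimal sketch of what such a callback can look like. The layer name, the choice of validation images, and averaging over the filter axis are illustrative assumptions rather than the exact FeatureLogger used in this report; see the linked code for the full version.

```python
import numpy as np
import tensorflow as tf
import wandb


class FeatureLogger(tf.keras.callbacks.Callback):
    """Logs the activations of one intermediate layer for a few
    validation images to W&B at the end of every epoch."""

    def __init__(self, validation_images, layer_name="conv2d", num_images=4):
        super().__init__()
        self.validation_images = validation_images[:num_images]
        self.layer_name = layer_name

    def on_epoch_end(self, epoch, logs=None):
        # Sub-model that maps the input to the chosen layer's output.
        feature_extractor = tf.keras.Model(
            inputs=self.model.inputs,
            outputs=self.model.get_layer(self.layer_name).output,
        )
        activations = feature_extractor.predict(self.validation_images, verbose=0)

        panels = []
        for act in activations:  # act shape: (H, W, num_filters)
            # Average over the filter axis to get one 2D map per image.
            act_map = act.mean(axis=-1)
            act_map = (act_map - act_map.min()) / (act_map.max() - act_map.min() + 1e-8)
            panels.append(wandb.Image((act_map * 255).astype(np.uint8),
                                      caption=f"epoch {epoch}"))
        wandb.log({"first_conv_activations": panels})
```

You would attach it like any other Keras callback, for example alongside WandbCallback in model.fit(..., callbacks=[...]).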

Class Activation Maps

It has been observed that the convolutional units of various layers of a convolutional neural network act as object detectors, even though no prior about the location of the object is provided while training the network for a classification task. Although convolutional layers have this remarkable property, it is lost when fully connected layers are used for classification. To avoid fully connected layers, some architectures such as Network in Network (NiN) and GoogLeNet are fully convolutional neural networks.
Global Average Pooling (GAP) is a very commonly used layer in such architectures. It is mainly used as a regularizer to prevent overfitting during training. The authors of Learning Deep Features for Discriminative Localization found that, by tweaking such an architecture, they could extend the advantages of GAP and retain the network's localization ability up to the last layer. Let's quickly walk through the procedure for generating a CAM using GAP.
A class activation map simply indicates the discriminative regions in the image that the CNN uses to classify that image into a particular category. For this technique, the network consists of a ConvNet, and just before the softmax layer (for multi-class classification), global average pooling is performed on the convolutional feature maps. The output of this pooling is used as features for a fully connected layer that produces the desired classification output. Given this simple connectivity structure, we can identify the importance of the image regions by projecting the weights of the output layer back onto the convolutional feature maps.
Figure 1: The network architecture ideal for CAM (Source)
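Concretely, the class activation map for a class is a weighted sum of the last convolutional layer's feature maps, where the weights come from the dense layer that follows global average pooling. Here is a minimal NumPy sketch of that projection (the function and variable names are illustrative, not taken from the paper's code):

```python
import numpy as np


def class_activation_map(feature_maps, dense_weights, class_idx):
    """feature_maps: (H, W, K) output of the last conv layer for one image.
    dense_weights: (K, num_classes) kernel of the dense layer after GAP."""
    w_c = dense_weights[:, class_idx]                         # (K,) weights for the target class
    cam = np.tensordot(feature_maps, w_c, axes=([2], [0]))    # weighted sum over the K feature maps
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # scale to [0, 1] for visualization
    return cam  # upsample to the input resolution before overlaying on the image
```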
Let’s try to implement this. 😄

Step 1: Modify Your Model

Suppose you have built your deep classifier with Conv blocks and a few fully connected layers. We will have to modify this architecture so that there aren't any fully connected layers in between. We will place a GlobalAveragePooling2D layer between the last convolutional block and the output layer (softmax/sigmoid), as sketched below.
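A minimal sketch of such a model in Keras is shown here; the backbone, filter counts, and layer names are placeholders rather than the exact architecture used in this report:

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_cam_model(input_shape=(224, 224, 3), num_classes=2):
    inputs = tf.keras.Input(shape=input_shape)
    # Convolutional backbone; the last conv layer's feature maps are what CAM projects onto.
    x = layers.Conv2D(32, 3, activation="relu", padding="same")(inputs)
    x = layers.MaxPooling2D()(x)
    x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
    x = layers.MaxPooling2D()(x)
    x = layers.Conv2D(128, 3, activation="relu", padding="same", name="last_conv")(x)
    # Global average pooling replaces the usual Flatten + Dense stack.
    x = layers.GlobalAveragePooling2D()(x)
    # A single dense layer maps the pooled features straight to class scores.
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)
```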

Gradient-Weighted Class Activation Maps

Even though CAM was amazing, it had some limitations: it only applies to architectures in which global average pooling over the convolutional feature maps feeds directly into the output layer, so other networks have to be modified and retrained before CAM can be used.
So what makes a good visual explanation? It should be class-discriminative (it localizes the target category in the image) and high-resolution (it captures fine-grained detail).
Thus the authors of Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, a really amazing paper, came up with a modification of CAM and previous approaches. Their method uses the gradients of any target class flowing into the final convolutional layer to produce a coarse localization map that highlights the important regions in the image for predicting that class.
Grad-CAM is therefore a strict generalization of CAM. Besides overcoming the limitations of CAM, it is applicable to different deep learning tasks involving CNNs, such as image classification, image captioning, and visual question answering, without requiring any architectural changes or retraining.
Let's implement this 😄
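One possible sketch of Grad-CAM with tf.GradientTape is shown below, assuming a Keras model whose last convolutional layer is named "last_conv" (as in the CAM model sketched earlier); the full implementation is in the linked code.

```python
import numpy as np
import tensorflow as tf


def grad_cam(model, image, last_conv_layer_name="last_conv", class_idx=None):
    """image: one preprocessed image of shape (H, W, C)."""
    # Model that exposes both the last conv feature maps and the predictions.
    grad_model = tf.keras.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer_name).output, model.output],
    )
    with tf.GradientTape() as tape:
        conv_maps, preds = grad_model(image[np.newaxis, ...])
        if class_idx is None:
            class_idx = int(tf.argmax(preds[0]))  # default: the predicted class
        class_score = preds[:, class_idx]
    # Gradients of the class score w.r.t. the last conv layer's feature maps.
    grads = tape.gradient(class_score, conv_maps)
    # Global-average-pool the gradients: one importance weight per channel.
    weights = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Weighted sum of the feature maps, then ReLU to keep only positive influence.
    cam = tf.nn.relu(tf.reduce_sum(conv_maps[0] * weights, axis=-1))
    cam = cam / (tf.reduce_max(cam) + 1e-8)
    return cam.numpy()  # upsample to the input size before overlaying on the image
```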

Conclusion

Class Activation Maps and Grad-CAM are two approaches that introduce some explainability/interpretability into deep learning models, and they are quite widely used. What's most fascinating about these techniques is their ability to perform object localization even though the model was never trained with a location prior. Grad-CAM, when used for image captioning, can help us understand which regions of the image are used to generate a certain word. When used for a visual question answering task, it can help us understand why the model arrived at a particular answer. Even though Grad-CAM is class-discriminative and localizes the relevant image regions, it lacks the ability to highlight fine-grained details the way pixel-space gradient visualization methods like Guided Backpropagation and Deconvolution do. Thus the authors combined Grad-CAM with Guided Backpropagation to form Guided Grad-CAM.
Thanks for reading this report until the end. I hope you find the callbacks introduced here helpful for your deep learning wizardry. Please feel free to reach out to me on Twitter (@ayushthakur0) with any feedback on this report. Thank you.

Try it out in a Google Colab →