Interpretability in Deep Learning With Weights & Biases: CAM and Grad-CAM
This article reviews how Grad-CAM counters the common criticism that neural networks are not interpretable.
Training a classification model is interesting, but have you ever wondered how your model makes its predictions? Is your model actually looking at the dog in the image before classifying it as a dog with 98% confidence? Interesting, isn't it?
In this article, we'll explore why deep learning models need to be interpretable, and some interesting methods to peek under the hood of a deep learning model. Deep learning interpretability is a very exciting area of research and much progress is being made in this direction already.
So, why should you care about interpretability? After all, the success of your business or project is often judged primarily by your model's accuracy. But in order to deploy our models in the real world, we need to consider other factors too. For instance, is the model racially biased? Or, what if it classifies humans with 97% accuracy overall, but achieves 99% accuracy on men and only 95% on women?
Understanding how a model makes its predictions can also help us debug our networks. (Check out this blog post on 'Debugging Neural Networks with PyTorch and W&B Using Gradients and Visualizations' for some other techniques that can help.)
At this point, we are all familiar with the idea that deep learning models make predictions based on learned representations expressed in terms of other, simpler representations. That is, deep learning allows us to build complex concepts out of simpler ones. Here's an amazing Distill Pub post to help you understand this concept better. We also know that these representations are learned while we train the model on our input data and labels, in the case of a supervised learning task like image classification. One of the criticisms of this approach is that the learned features in a neural network are not interpretable.
Today we'll look at three techniques that address this criticism and shed light on the "black-box" nature of neural network learning:
- Visualizing learned features
- Class Activation Maps (CAM)
- Gradient-weighted Class Activation Maps (Grad-CAM)
Table of Contents
- Visualizing Activations
  - Implement Feature Logger
  - Observations
  - Usage
- Class Activation Maps
  - Step 1: Modify Your Model
  - Step 2: Retrain Your Model With CAMLogger callback
- Gradient-Weighted Class Activation Maps
  - Step 1: Your Deep Learning Task
  - Step 2: Use GRADCamLogger While Training Your Model
  - Step 3: Use GRADCamLogger and train
- Conclusion
Visualizing Activations
Neural networks trained to solve image classification problems typically have at least one convolutional layer. We will focus on image classification and try to build some intuition for what the features learned by these layers look like.
The most straightforward way to make sense of what's going on inside a neural network is to visualize the activations of the intermediate layers during the forward pass. For this report, I built a simple cat and dog image classifier. We are interested in visualizing the output of the first convolutional layer. To do so, I built a simple FeaturesLogger Keras callback with W&B integration.
Implement Feature Logger
```python
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import wandb

class FeaturesLogger(tf.keras.callbacks.Callback):
    def __init__(self, validation_data, layer_name):
        '''
        validation_data: tuple of the form (sample_images, sample_labels).
        layer_name: name of the layer whose features we are interested in.
        '''
        super(FeaturesLogger, self).__init__()
        self.validation_data = validation_data
        self.layer_name = layer_name

    def on_epoch_end(self, epoch, logs=None):
        ## Build an intermediate model that outputs the target layer's activations
        self.intermediate_model = keras.models.Model(
            inputs=self.model.input,
            outputs=self.model.get_layer(self.layer_name).output)

        ## Unpack validation data
        images, labels = self.validation_data
        for image, label in zip(images, labels):
            ## Compute the output activations of the provided layer
            img = np.expand_dims(image, axis=0)
            features = self.intermediate_model.predict(img)
            features = features.reshape(features.shape[1:])
            features = np.rollaxis(features, 2, 0)

            ## Prepare the plot to be logged to wandb
            fig, axs = plt.subplots(nrows=4, ncols=8, figsize=(15, 8))
            c = 0
            for i in range(4):
                for j in range(8):
                    axs[i][j].imshow(features[c], cmap='gray')
                    axs[i][j].set_xticks([])
                    axs[i][j].set_yticks([])
                    c += 1
            wandb.log({"features_labels_{}".format(label): plt})
            plt.close()
```
After every training epoch, the FeaturesLogger callback is invoked. You need to feed it one validation image from each label to visualize the activations, and provide the layer name (here we are interested in the first layer, conv2d).
We build an intermediate_model to get the output activations of that layer; this is really easy to do in Keras. We then iterate over each sample and call intermediate_model.predict(). The output of this call is the output of that convolutional layer, in my case a NumPy array of shape (1, 148, 148, 32).
Next, we do some post-processing. Notice the use of np.rollaxis to convert the features array of shape (148, 148, 32) to (32, 148, 148); the 32 corresponds to the number of filters in the conv2d layer. Using W&B, we can easily log the generated plt figure.
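As a quick sanity check of that axis shuffle (using a dummy array rather than the actual activations):

```python
import numpy as np

features = np.zeros((148, 148, 32))      # (height, width, channels)
features = np.rollaxis(features, 2, 0)   # move the channel axis to the front
print(features.shape)                    # (32, 148, 148)
```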
Let's look at the plots 👇
[W&B panel: activation maps of the first convolutional layer for sample validation images]
Observations:
- From the chart on our left 👈 we can say that the model wasn't able to fit well on the training data for the given network configuration and number of epochs.
- From the charts above ☝️ we can clearly see that the first convolutional layer retains the full shape of the input image. That is, the initial layer(s) retain the spatial information of the image.
- As you go deeper, the layers begin to encode high-level representations. These representations tend to retain little to no visual content; instead, they are rich with information about your labels.
Usage:
- One dangerous pitfall can easily be spotted in this visualization: some activation maps may be all zero for many different inputs, which can indicate dead filters and can be a symptom of a learning rate that is too high.
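For completeness, here's a minimal sketch of how the callback might be wired into training. The dataset objects, sample images, and labels below are placeholders rather than the exact code used for this report:

```python
# Hypothetical wiring of the callback; `train_ds`, `val_ds`, `sample_images`,
# and `sample_labels` are placeholders for your own data.
feature_logger = FeaturesLogger(validation_data=(sample_images, sample_labels),
                                layer_name='conv2d')

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(train_ds,
          validation_data=val_ds,
          epochs=10,
          callbacks=[feature_logger])
```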
Class Activation Maps
Step 1: Modify Your Model
Suppose you have built your deep classifier with Conv blocks and a few fully connected layers. We will have to modify this architecture so that there aren't any fully connected layers, and place a GlobalAveragePooling2D layer between the last convolutional block and the output layer (softmax/sigmoid). The idea behind CAM is that the weights of this output layer, applied to the globally average-pooled feature maps of the last convolutional layer, tell us how important each feature map is for a given class.
```python
from tensorflow import keras
from tensorflow.keras.applications import VGG16

def flatten_model(model_nested):
    '''Utility to flatten a pretrained (nested) model into a list of layers.'''
    layers_flat = []
    for layer in model_nested.layers:
        try:
            layers_flat.extend(layer.layers)
        except AttributeError:
            layers_flat.append(layer)
    return layers_flat

def CAMmodel():
    ## Simulating my pretrained dog and cat classifier.
    vgg = VGG16(include_top=False, weights='imagenet')
    vgg.trainable = False
    ## Flatten the layers so that they are not nested in the Sequential model.
    vgg_flat = flatten_model(vgg)
    ## Insert GAP followed by the output layer.
    vgg_flat.append(keras.layers.GlobalAveragePooling2D())
    vgg_flat.append(keras.layers.Dense(1, activation='sigmoid'))
    model = keras.models.Sequential(vgg_flat)
    return model
```
Initialize the Modified Model:
```python
keras.backend.clear_session()
model = CAMmodel()
model.build((None, None, None, 3))  # Note: the spatial dimensions are left unspecified
model.summary()
```
CAMmodel applies the required modification to our cat and dog classifier. Here I am using a pre-trained VGG16 model to simulate my already-trained cat-dog classifier.
A simple utility, flatten_model, returns the list of layers in my pre-trained model. This is done so that the layers are not nested when wrapped in a Sequential model, and so that the last convolutional layer can be accessed and used as an output. I appended GlobalAveragePooling2D and Dense layers to the array returned by flatten_model. Finally, the Sequential model is returned.
Next, we call model.build() with the appropriate model input shape.
Step 2: Retrain Your Model With CAMLogger callback
Since a new layer was introduced, we have to retrain the model. But we don't need to retrain the entire model: we can freeze the convolutional blocks with vgg.trainable = False and train only the newly added layers.
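The CAMLogger callback itself isn't shown here, but the core CAM computation it relies on is simple: weight the last convolutional feature maps by the weights of the final Dense layer. Below is a minimal sketch of that idea, not the exact callback used in this report, assuming a functional variant of the same conv → GAP → Dense(1) model with a fixed 224×224 input so that both the feature maps and the prediction come from a single forward pass:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical functional variant of the CAM model above.
vgg = tf.keras.applications.VGG16(include_top=False, weights='imagenet',
                                  input_shape=(224, 224, 3))
vgg.trainable = False
gap = keras.layers.GlobalAveragePooling2D()(vgg.output)
out = keras.layers.Dense(1, activation='sigmoid', name='cam_dense')(gap)
cam_model = keras.models.Model(inputs=vgg.input, outputs=[vgg.output, out])

def compute_cam(cam_model, image):
    """CAM = last-conv feature maps weighted by the output layer's weights."""
    features, _ = cam_model.predict(np.expand_dims(image, axis=0))
    features = features[0]                                              # (h, w, channels)
    class_weights = cam_model.get_layer('cam_dense').get_weights()[0]   # (channels, 1)
    heatmap = features @ class_weights[:, 0]                            # (h, w)
    heatmap = np.maximum(heatmap, 0)                                    # keep positive evidence only
    heatmap /= (heatmap.max() + 1e-8)                                   # normalize to [0, 1]
    return heatmap                                                      # upsample & overlay as needed
```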
Observations:
- There is a decline in the model performance in terms of both training and validation accuracy. The best training and validation accuracy I achieved was 99.01% and 95.67%, respectively.
- Thus, to implement CAM, we have to modify our architecture, and this modification comes with a drop in model performance.
Gradient-Weighted Class Activation Maps
Even though CAM was amazing, it had some limitations:
- The model needs to be modified in order to use CAM.
- The modified model needs to be retrained, which is computationally expensive.
- Since the fully connected Dense layers are removed, the model performance will surely suffer. This means the prediction score doesn't give the full picture of the model's ability.
- The use case was bound by architectural constraints, i.e., architectures performing GAP over convolutional maps immediately before the output layer.
What makes a good visual explanation?
- Certainly, the technique should localize the class in the image. We saw this in CAM and it worked remarkably well.
- Finer details should be captured, i.e., the activation map should be high resolution.
Thus the authors of Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, a really amazing paper, came up with a modification to CAM and previous approaches. Their approach uses the gradients of any target prediction flowing into the final convolutional layer to produce a coarse localization map that highlights the important regions in the image for predicting the class.
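Concretely, the paper global-average-pools the gradients of the class score $y^c$ with respect to the feature maps $A^k$ of the last convolutional layer to get per-channel importance weights, and then applies a ReLU to the weighted combination of those maps:

$$\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}}, \qquad L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\left(\sum_k \alpha_k^c A^k\right)$$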
Thus Grad-CAM is a strict generalization of CAM. Besides overcoming the limitations of CAM, it's applicable to a variety of deep learning tasks involving CNNs:
- CNNs with fully-connected layers (e.g. VGG) without any modification to the network.
- CNNs used for structured outputs like image captioning.
- CNNs used in tasks with multi-modal inputs like visual Q&A or reinforcement learning, without architectural changes or re-training.
Let's implement this 😄
```python
import tensorflow as tf
from tensorflow import keras

def catdogmodel():
    inp = keras.layers.Input(shape=(224, 224, 3))
    vgg = tf.keras.applications.VGG16(include_top=False, weights='imagenet',
                                      input_tensor=inp, input_shape=(224, 224, 3))
    vgg.trainable = False
    x = vgg.get_layer('block5_pool').output
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(64, activation='relu')(x)
    output = keras.layers.Dense(1, activation='sigmoid')(x)
    model = tf.keras.models.Model(inputs=inp, outputs=output)
    return model
```
Step 1: Your Deep Learning Task
We will focus on the image classification task. Unlike with CAM, we don't have to modify our model or retrain it.
I have used a VGG16 model pre-trained on ImageNet as my base model, and I'm simulating transfer learning with it.
The layers of the base model are made non-trainable by setting vgg.trainable = False. Note that, unlike with CAM, the model keeps fully connected layers.
Step 2: Use GRADCamLogger While Training Your Model
You will find the class GradCAM in the linked notebook. This is a modified implementation from [Grad-CAM: Visualize class activation maps with Keras, TensorFlow, and Deep Learning](https://www.pyimagesearch.com/2020/03/09/grad-cam-visualize-class-activation-maps-with-keras-tensorflow-and-deep-learning/), an amazing blog post by Adrian Rosebrock of PyImageSearch.com. I would highly suggest checking out the step-by-step implementation of the GradCAM class in that blog post.
I made two modifications to it:
- While doing transfer learning, that is, if your target (last) convolutional layer is non-trainable, tape.gradient(loss, convOutputs) will return None. This is because GradientTape does not, by default, watch non-trainable variables/layers. To compute gradients with respect to that layer's output, you need to tell the tape to watch it explicitly by calling tape.watch() on the target layer's output tensor. Hence the change:
```python
with tf.GradientTape() as tape:
    tape.watch(self.gradModel.get_layer(self.layerName).output)
    inputs = tf.cast(image, tf.float32)
    (convOutputs, predictions) = self.gradModel(inputs)
```
- The original implementation didn't account for binary classification, and the original authors also talked about softmax-ing the output. So, in order to train a simple cat and dog classifier, I made a small modification. Hence the change:
```python
if len(predictions) == 1:
    # Binary classification
    loss = predictions[0]
else:
    loss = predictions[:, classIdx]
```
The GradCAM class can be used after the model is trained, or as a callback so that you can see the activation maps while training. Here's a small excerpt from his blog post:

[Excerpt from the PyImageSearch blog post]
The third point in that excerpt motivated me to work on this project. I built a custom callback around this GradCAM implementation and used wandb.log() to log the activation maps. By using this callback, you can apply Grad-CAM while training.
Step 3: Use GRADCamLogger and train
Given that we're working with a simple dataset, I have only trained for a few epochs, and the model seems to work well.
Here's the GradCAM custom callback.
```python
import numpy as np
import cv2
import tensorflow as tf
import wandb

class GRADCamLogger(tf.keras.callbacks.Callback):
    def __init__(self, validation_data, layer_name):
        super(GRADCamLogger, self).__init__()
        self.validation_data = validation_data
        self.layer_name = layer_name

    def on_epoch_end(self, epoch, logs=None):
        images = []
        grad_cam = []

        ## Initialize the GradCAM class
        cam = GradCAM(self.model, self.layer_name)

        for image in self.validation_data:
            image = np.expand_dims(image, 0)
            pred = self.model.predict(image)
            classIdx = np.argmax(pred[0])

            ## Compute the heatmap
            heatmap = cam.compute_heatmap(image, classIdx)

            image = image.reshape(image.shape[1:])
            image = image * 255
            image = image.astype(np.uint8)

            ## Overlay the heatmap on the original image
            heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))  # cv2 expects (width, height)
            (heatmap, output) = cam.overlay_heatmap(heatmap, image, alpha=0.5)
            images.append(image)
            grad_cam.append(output)

        wandb.log({"images": [wandb.Image(img) for img in images]})
        wandb.log({"gradcam": [wandb.Image(g) for g in grad_cam]})
```
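As with the feature logger, here's a minimal sketch of how this callback might be attached during training. The dataset objects, sample images, and W&B project name are assumptions for illustration; "block5_conv3" is VGG16's last convolutional layer:

```python
# Hypothetical training setup; `train_ds`, `val_ds`, and `sample_images`
# are placeholders for your own data.
import wandb
from wandb.keras import WandbCallback

wandb.init(project="interpretability")

model = catdogmodel()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

model.fit(train_ds,
          validation_data=val_ds,
          epochs=5,
          callbacks=[WandbCallback(),
                     GRADCamLogger(sample_images, layer_name='block5_conv3')])
```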
Conclusion
Class Activation Maps and Grad-CAM are approaches that introduce some explainability/interpretability into deep learning models, and they are quite widely used. What's most fascinating about these techniques is their ability to localize objects even though the model was never trained with any location prior. Grad-CAM, when used for image captioning, can help us understand which regions of the image are used to generate a certain word. When used for a visual question answering task, it can help us understand why the model arrived at a particular answer. Even though Grad-CAM is class-discriminative and localizes the relevant image regions, it lacks the ability to highlight fine-grained details the way pixel-space gradient visualization methods like Guided Backpropagation and Deconvolution do. Thus the authors combined Grad-CAM with Guided Backpropagation to get Guided Grad-CAM.
Thanks for reading this report until the end. I hope you find the callbacks introduced here helpful for your deep learning wizardry. Please feel free to reach out to me on Twitter (@ayushthakur0) with any feedback on this report. Thank you.
Try it out in a Google Colab →