Visualise Failure - Debugging with Model Activations

Visualising your model's activations can help you spot when your model is training sub-optimally, even if it appears to train "successfully".
Here I visualise how the activations generated while training a transformer model (GPT) vary across a number of different training modes; the Weights & Biases workspace for this project can be found here. This work is based on the brilliant activations visualisation work by Stefano Giomo in the fast.ai forums and described in the fast.ai course.
Code used to create the Baseline run of this report can be found in this colab.

Language Modelling with minGPT & Shakespeare

I used the minGPT repo from Andrej Karpathy, which is based on the original GPT architecture, the granddaddy of GPT-3. It's a clean and simple codebase that makes this kind of experimentation easy. The language model was trained using the fast.ai library on the "Tiny Shakespeare" dataset.
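For reference, the data side only needs a small character-level dataset along these lines. This is a rough sketch in the spirit of minGPT's demo notebook rather than the exact code from the colab; the file path and block size are assumptions.

import torch
from torch.utils.data import Dataset

class CharDataset(Dataset):
    # Character-level language modelling: each item is a chunk of encoded
    # characters (x) and the same chunk shifted one position ahead (y).
    def __init__(self, text, block_size=128):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}  # char -> int
        self.block_size = block_size
        self.data = [self.stoi[ch] for ch in text]

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

text = open("tinyshakespeare.txt").read()  # assumed local copy of the dataset
train_ds = CharDataset(text, block_size=128)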

Layers, Layers & More Layers - Where to Look?

When looking at charts of activations data it can be hard to know which layer(s)' activations to look at: should I look early in the model? If there isn't enough going on in the early stages, should I look late in the model? Or maybe the interesting stuff is actually happening in the middle...
I found that for this architecture, the activations from the FeedForward layers of the transformer blocks and from the final Linear layer seemed to be the most informative, as they changed the most as the baseline model learned. They were also the layers that changed most dramatically compared to the baseline when deliberately poor hyperparameters were used for training.
For this example, I visualised activations from the embedding layer, final Attention Layer, and final Linear layer of the model.
Tip: When visualising the activations, try a training run with deliberately poor settings (e.g. very high or low learning rate) and see which layers change the most. These could be good layers to concentrate on for future visualisations.

Setup

I tested five different training modes to examine how the activations differed under each scenario.
I used fast.ai's ActivationStats hook to grab the activations from my chosen layers and generate the histogram data used for the visualisations. These plots can also be obtained by running learn.activation_stats.color_dim(layer_index) in your notebook when using the ActivationStats callback.
Training was carried out for 50 epochs at a batch size of 256 with the AdamW optimizer. Each epoch took about 20-30 seconds depending on whether activations were being logged or not. One-Cycle training was used, with the peak learning rate occurring at the default 25% of the total training steps.
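Putting that setup together, a minimal sketch of the training code might look like the following. Here model (the minGPT model, assumed to return logits) and dls (fastai DataLoaders built over the character dataset with batch size 256) are assumed to exist already, and the peak learning rate is an illustrative value rather than the one used for these runs.

from fastai.learner import Learner
from fastai.losses import CrossEntropyLossFlat
from fastai.callback.hook import ActivationStats

# with_hist=True makes the callback store the per-step histogram data
# used for the HOT charts below.
astats = ActivationStats(with_hist=True)

# fastai's default optimizer is Adam with decoupled weight decay, i.e. AdamW.
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=astats)

# 50 epochs of One-Cycle training; pct_start=0.25 puts the peak learning
# rate at 25% of the total training steps (this is also the default).
learn.fit_one_cycle(50, lr_max=3e-4, pct_start=0.25)

# Plot the "histograms over time" chart for a hooked layer by index,
# e.g. the last hooked layer:
learn.activation_stats.color_dim(-1)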
minGPT Initialisation vs PyTorch Defaults
A comparison of the minGPT initialisation vs the PyTorch default initialisation is below:
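In code, the difference boils down to something like the following sketch (paraphrased; see the minGPT repo for the exact initialisation code, and note that model here is assumed to be an instantiated GPT):

import torch.nn as nn

def mingpt_style_init(module):
    # minGPT-style initialisation: small normal weights, zero biases,
    # and LayerNorm reset to weight=1, bias=0.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)

model.apply(mingpt_style_init)  # the "Baseline" custom initialisation

# Skipping the call above leaves PyTorch's defaults in place
# (Kaiming-uniform Linear weights, N(0, 1) Embedding weights) --
# the "No Custom Initialisation" runs.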

Training Loss

Only training loss was monitored for this toy example.

Activations Charts

Now to take a peek at what the activations behind these loss curves can actually tell us! By visualising the embedding activations you can see the impact of different training settings. We'll look at two visualisations: "Activations near Zero" and an activations "histograms over time" (HOT) chart.

Activations Near Zero

An "Activations near Zero" plot displays the percentage of activations whose value is near zero (0 ± 0.05). If we see strange behaviour here it can be a prompt to further investigate.
Tip: Looking at the mean and standard deviation of your activations can also help, and fast.ai's ActivationStats callback can record these too; an example is here.
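The statistic itself is simple to reproduce from a captured activation tensor. A quick sketch, where acts is stand-in data and the 0.05 threshold matches the definition above:

import torch

def near_zero_pct(acts, thresh=0.05):
    # Percentage of activations whose value lies in [-thresh, thresh].
    return (acts.abs() <= thresh).float().mean().item() * 100

acts = torch.randn(256, 128)  # stand-in for a layer's captured activations
print(near_zero_pct(acts))    # % of activations near zero
print(acts.mean().item(), acts.std().item())  # mean / std, per the tip above

Raising or lowering thresh gives the "change the threshold" variant mentioned later for the final Linear layer.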

Activations Histograms Over Time

In an activations HOT chart, the x-axis represents the training steps, from the first step to the last. Every interval along the x-axis is a histogram of activation values at that step. Histogram bins are normally displayed along the x-axis, but here the bins (60 in this example) are displayed along the y-axis. The activation values between -10 and 10 are binned, and then the log of one plus the count of activations in each bin, log(1 + count), is taken (the +1 deals with bins whose count is zero). We use the log of the count because it provides stronger visual contrast, as discussed by the folks in the fast.ai forums. The brighter the color of a bin, the higher the count of activations in that bin.
Activations HOT charts are useful for inspecting how the distribution of activations from a layer evolves over time; these distributions can tell us whether the model is behaving as expected during training.
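To make the construction concrete, here is a rough sketch of how such a chart could be assembled from activations captured at each step, following the recipe above (60 bins over [-10, 10], coloured by log(1 + count)) rather than fast.ai's exact implementation. The fake activation data is only there to make the snippet runnable.

import torch
import matplotlib.pyplot as plt

def activation_hot(acts_per_step, n_bins=60, lo=-10.0, hi=10.0):
    # Build a (n_bins, n_steps) matrix: each column is the log(1 + count)
    # histogram of one training step's activations.
    cols = []
    for acts in acts_per_step:
        counts = torch.histc(acts.float(), bins=n_bins, min=lo, max=hi)
        cols.append(torch.log1p(counts))  # log(1 + count) for visual contrast
    return torch.stack(cols, dim=1)

# Fake data: activations that slowly spread out over 200 steps.
acts_per_step = [torch.randn(4096) * (1 + step / 50) for step in range(200)]
hot = activation_hot(acts_per_step)

plt.imshow(hot, origin="lower", aspect="auto", extent=[0, len(acts_per_step), -10, 10])
plt.xlabel("training step")
plt.ylabel("activation value")
plt.show()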

Visualisation Results

Embedding Layer - Activations Near Zero

Generally we don't expect the magnitude of the text embeddings to vary radically in a training run like this, as the embedding matrix is quite sparse, with only a few vectors being updated in each training step.

Embedding Layer - Activations HOT

In this case, the activations HOT chart can show us whether our model is initialised as expected. Since we expect most activations to be close to zero, the broad distribution of activation values in the second chart in the top row (No Custom Initialisation) could be a warning flag.

Final Attention Layer - Activations Near Zero

The activations near zero for the output of the final Attention Block were also logged.

Final Attention Layer - Activations HOT

I didn't create a chart for the output of the final Attention Layer because it's quite visually uninteresting, with most activations sitting quite close to zero, similar to the outputs from the embedding layer. Activations Near Zero is the more informative chart for this layer.

Final Linear Layer - Activations Near Zero

We also don't see much of interest happening in the Activations Near Zero chart for the final Linear layer of the model. For this layer, we might want to either change the threshold value of what we consider to be "near zero", or look at other statistics such as the mean or standard deviation.

Final Linear Layer - Activations HOT

Looking at the final Linear layer of the network, we have a much more interesting picture.

Summary

Visualising your activations can provide potentially valuable insight into how your model training is going and whether its architectural settings or training hyperparameters might benefit from a few tweaks.
I would encourage anyone interested to further explore this form of debugging. I think there's an opportunity here to uncover some valuable debugging heuristics for ML practitioners!