Interpret any PyTorch Model Using W&B Embedding Projector
An introduction to our embedding projector with the help of some furry friends
Created on January 13|Last edited on May 12
The Question
Let's say you're working on the Oxford-IIIT Pet Dataset to classify pet breeds, and you want to know which breeds look alike and which look different. How would you go about solving this problem?
Before we dig in too deeply, let's start by looking at the dataset. Chances are, even if you've never seen it, you have a pretty good idea of what to expect: lots and lots of pets.
There are 37 pet breeds in the dataset, and the table above shows the first 30-40 images for each breed.
Now, back to that question! How could you possibly know which breeds look alike and which look different? It would have to be through some sort of grouping, right? But how exactly?
At this point, I really want you to take a minute and think about the solution before proceeding any further. :)
An Answer
A clever way to achieve this would be to train a deep learning model on the pets dataset, and then use this trained model to learn more about the dataset. That's right: you could use this trained model to learn more about the data itself.
"Ah, but how?" you ask.
Well, you could pass images from all 37 classes through the trained model, extract their image embeddings, and then use a dimensionality reduction technique such as PCA, t-SNE, or UMAP to plot these embeddings on a chart!
This would place similar-looking pets close together and different-looking pets further apart, as you can see in the chart below. In other words, categories that sit close to each other in the scatter plot below are likely to look similar.
As an example, two of the breeds that are very close to each other in the plot above are "British Shorthair" and "Russian Blue". Let's see if these breeds look alike:
Okay! So far, so good. Our technique looks promising.
Let's look at another example though. Here, let's compare "Yorkshire Terrier" and "Keeshond." These are again quite close to each other in the scatter plot above and, as you can see below, the two breeds have some marked similarities.
Those good boys seem pretty related to me.
So what exactly is going on? How did we manage to create this scatter plot using image embeddings? Why do similar-looking pets appear closer to each other in the scatter plot above, and different-looking pets further apart? In the next section, we explain everything in detail.
The Method
What are Image Embeddings?
When we train a deep learning model on a dataset, the model learns patterns in that data. For example, on the pets dataset, the model learns what fur or eyes look like. We can then pass pet images through the trained model and take the vector output of its penultimate layer (the layer just before the final classification layer). These vector representations are referred to as image embeddings.
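To make "penultimate layer" concrete, here is a minimal sketch using a hypothetical toy classifier (the layer sizes and the names `backbone`, `pool`, `flatten`, and `fc` are assumptions for illustration, not the fastai model used in this report). Everything before the final classification layer produces the embedding, and because `nn.Sequential` supports slicing, `model[:-1]` runs every layer except the last.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical toy classifier for 37 pet breeds (not the report's trained model)
model = nn.Sequential(OrderedDict([
    ("backbone", nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 32, kernel_size=3, padding=1),
        nn.ReLU(),
    )),
    ("pool", nn.AdaptiveAvgPool2d(1)),  # (N, 32, H, W) -> (N, 32, 1, 1)
    ("flatten", nn.Flatten()),          # (N, 32, 1, 1) -> (N, 32)
    ("fc", nn.Linear(32, 37)),          # final layer: one logit per breed
]))
model.eval()

x = torch.randn(4, 3, 64, 64)  # a batch of 4 random "images"
with torch.no_grad():
    logits = model(x)           # last layer: class scores
    embeddings = model[:-1](x)  # penultimate layer: image embeddings

print(logits.shape)      # torch.Size([4, 37])
print(embeddings.shape)  # torch.Size([4, 32])
```

Slicing only works for a plain `nn.Sequential`; for architectures with a custom `forward`, such as ResNet, the forward-hook approach described below is more general.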
How do we get image embeddings from any PyTorch model?
If you haven't come across PyTorch hooks before, PyTorch 101, Part 5: Understanding Hooks is a great place to get started! (For the record, this section assumes that the reader knows about PyTorch hooks.)
We can register a forward hook on the penultimate layer of any PyTorch model to capture its output as image embeddings, as long as we know the layer's name.
Let's say we want the output of the average pooling layer (avgpool) of the ResNet-34 architecture. In code:
```python
from collections import defaultdict

import torch
import torch.nn as nn
import torchvision


class FeatureExtractor(nn.Module):
    def __init__(self, model, layer_names):
        super().__init__()
        self.model = model
        self.layer_names = layer_names
        self._features = defaultdict(list)

        # Map every submodule name (e.g. "avgpool", "layer4.2.conv1") to its module
        layer_dict = dict([*self.model.named_modules()])
        for layer_name in layer_names:
            layer = layer_dict[layer_name]
            layer.register_forward_hook(self.save_outputs_hook(layer_name))

    def save_outputs_hook(self, layer_name):
        # Return a hook that stores this layer's output under its name
        def fn(_, __, output):
            self._features[layer_name] = output
        return fn

    def forward(self, x):
        _ = self.model(x)  # running the model triggers the registered hooks
        return self._features


x = torch.randn(1, 3, 224, 224)
model = torchvision.models.resnet34()
fx = FeatureExtractor(model, ["avgpool"])
print(fx(x)["avgpool"].shape)
# >> torch.Size([1, 512, 1, 1])
```
The FeatureExtractor class above can be used to register a forward hook on any module inside a PyTorch model. For each name in layer_names, FeatureExtractor looks up the matching module and registers the hook returned by save_outputs_hook(layer_name). As per the PyTorch docs, the hook is called every time after forward() has computed an output.
A PyTorch hook has the following signature:
hook(module, input, output) -> None or modified output
Therefore, save_outputs_hook returns an inner function, fn, with exactly this signature; fn stores the layer's output in the self._features dictionary under that layer's name.
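To see the `hook(module, input, output)` contract in isolation, here is a minimal sketch on a single `nn.Linear` layer (the layer, the `store` dictionary, and the helper names `save_output` and `zero_output` are hypothetical). The factory returns a closure, just like `save_outputs_hook`. `register_forward_hook` returns a handle whose `remove()` method detaches the hook, and a hook that returns a tensor replaces the module's output, while returning `None` leaves it unchanged.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
store = {}

def save_output(name):
    # Factory returning a closure with the hook signature (module, input, output)
    def hook(module, inputs, output):
        store[name] = output.detach()  # returns None, so the output is unchanged
    return hook

def zero_output(module, inputs, output):
    # Returning a tensor replaces the module's output
    return torch.zeros_like(output)

x = torch.randn(3, 4)

handle = layer.register_forward_hook(save_output("linear"))
y = layer(x)
print(torch.equal(store["linear"], y.detach()))  # True: the hook saw the same output
handle.remove()  # detach the hook once it is no longer needed

zero_handle = layer.register_forward_hook(zero_output)
print(layer(x))  # all zeros: the hook's return value replaced the output
zero_handle.remove()

print(torch.allclose(layer(x), y))  # True: with no hooks, the normal output is back
```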
As its forward method shows, FeatureExtractor runs the wrapped model and returns the self._features dictionary as its output.
In our small ResNet-34 example, we simply register a forward hook on the avgpool layer and check the output shape. Try this class out for yourself in this colab!
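In practice we need embeddings for every image rather than a single random tensor, and an avgpool output of shape `(N, 512, 1, 1)` has to be flattened to `(N, 512)` before it fits into a table. Here is a minimal sketch, assuming a hypothetical tiny CNN with an `avgpool` layer and a `TensorDataset` of random tensors standing in for the pet images: register the hook once, run the data through in `eval()` mode under `torch.no_grad()`, and concatenate the flattened batches.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins: a tiny CNN with an "avgpool" layer and 10 random "images"
model = nn.Sequential()
model.add_module("conv", nn.Conv2d(3, 8, kernel_size=3, padding=1))
model.add_module("avgpool", nn.AdaptiveAvgPool2d(1))
model.add_module("flatten", nn.Flatten())
model.add_module("fc", nn.Linear(8, 37))

images = torch.randn(10, 3, 32, 32)
labels = torch.randint(0, 37, (10,))
loader = DataLoader(TensorDataset(images, labels), batch_size=4)  # shuffle=False keeps order

batch_outputs = []

def hook(module, inputs, output):
    # output has shape (B, 8, 1, 1); flatten everything after the batch dimension
    batch_outputs.append(torch.flatten(output, start_dim=1))

handle = model.avgpool.register_forward_hook(hook)
model.eval()  # switch dropout / batch-norm layers (if any) to inference behaviour
all_labels = []
with torch.no_grad():
    for x, y in loader:
        model(x)  # the hook records this batch's embeddings
        all_labels.append(y)
handle.remove()

features = torch.cat(batch_outputs).numpy()  # (10, 8) embedding matrix
label_array = torch.cat(all_labels).numpy()  # (10,) class indices
print(features.shape, label_array.shape)     # (10, 8) (10,)
```

The resulting `features` matrix and labels have the row-per-image layout that the table-logging code later in this report expects.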
How to get started with the W&B Embedding Projector plot
```python
import wandb

wandb.init(project="embedding_tutorial")

embeddings = [
    [0.2, 0.4, 0.1, 0.7, 0.5],  # embedding 1
    [0.3, 0.1, 0.9, 0.2, 0.7],  # embedding 2
    [0.4, 0.5, 0.2, 0.2, 0.1],  # embedding 3
]

wandb.log({
    "embeddings": wandb.Table(
        columns=["D1", "D2", "D3", "D4", "D5"],
        data=embeddings,
    )
})

wandb.finish()
```
After running the above code, your W&B dashboard will have a new Table containing your data. Select 2D Projection from the panel selector in the upper right to plot the embeddings in two dimensions.

Fig-1: Get started with W&B Embedding Projector
How to create the W&B Embedding Projector for the Pets dataset
To create all the tables and figures showcased in this blog post, you can follow along with this colab notebook.
As you can see in the notebook, once we have a trained model, we simply create a wandb Table with the embedding outputs and labels and log this table to W&B. In code, this looks as follows:
```python
import pandas as pd
import wandb

# `features` (shape (1478, 512)) and `labels` come from the trained model in the notebook
cols = [f"out_{i}" for i in range(features.shape[1])]

# create pandas dataframe from feature outputs of shape (1478, 512) and add labels
df = pd.DataFrame(features, columns=cols)
df["LABEL"] = labels

# log pandas DataFrame to W&B easily
table = wandb.Table(columns=df.columns.to_list(), data=df.values)

wandb.init(project="embedding_projector")
wandb.log({"Pet Breeds": table})
wandb.finish()
```
Running the above code logs this W&B Table. Next, as in the previous section, we open the table's settings and select "2D Projection Plot" to get our embedding plot! That's really all it takes.
About the Embedding Projector Plot
Now that we've logged the table to Weights & Biases (as shown in the docs), we can create the embedding projector plot using one of several dimensionality reduction algorithms: PCA, t-SNE, or UMAP.
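To build intuition for what the PCA option does, here is a minimal NumPy sketch of projecting embeddings to 2-D: center the data and keep the coordinates along the top two principal directions from an SVD. This is the textbook algorithm, not necessarily how W&B implements it, and t-SNE and UMAP are nonlinear methods that work differently. The random `features` matrix is a stand-in for real embeddings.

```python
import numpy as np

def pca_2d(features):
    """Project an (N, D) embedding matrix onto its top two principal components."""
    centered = features - features.mean(axis=0)
    # Rows of vt are the principal directions, sorted by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T  # (N, 2) coordinates for the scatter plot

rng = np.random.default_rng(0)
features = rng.normal(size=(100, 512))  # stand-in for 100 image embeddings
coords = pca_2d(features)
print(coords.shape)  # (100, 2)
```

The first output column always carries at least as much variance as the second, which is why PCA plots tend to spread out most along the horizontal axis.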

Fig-2: Parameters to create our Embedding Plot
Interpreting the Pet Embedding Projector Plot
Since we first log the image embeddings and the projector then uses a dimensionality reduction technique to plot them in 2-D space, points that appear close to each other have similar image embeddings. With t-SNE and UMAP, this holds most reliably for local neighborhoods; distances between far-apart clusters are less meaningful.
That means the model has created very similar vector representations for images that appear close together on the projection plot, which in turn suggests that those pet images look similar.
Similarly, points that appear far apart on the embedding plot correspond to pet images that are likely to look different from each other.
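Because the plot is only a 2-D summary, it can be worth checking a pair like "British Shorthair" and "Russian Blue" in the original embedding space too. Here is a minimal sketch, assuming `features` is an `(N, D)` embedding matrix and `labels` the matching breed names (the hypothetical helper `closest_breed_pairs` and the toy vectors below are made up for illustration): average each breed's embeddings into a centroid and rank breed pairs by cosine similarity.

```python
from itertools import combinations

import numpy as np

def closest_breed_pairs(features, labels, top_k=3):
    """Rank pairs of classes by cosine similarity of their mean embeddings."""
    labels = np.asarray(labels)
    classes = sorted(set(labels.tolist()))
    centroids = np.stack([features[labels == c].mean(axis=0) for c in classes])
    centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)  # unit length
    sims = centroids @ centroids.T  # cosine similarity matrix
    pairs = [(sims[i, j], classes[i], classes[j])
             for i, j in combinations(range(len(classes)), 2)]
    return sorted(pairs, reverse=True)[:top_k]

# Made-up 3-D "embeddings": the two grey cat breeds point in nearly the same direction
features = np.array([
    [1.0, 0.1, 0.0], [0.9, 0.2, 0.0],  # British Shorthair
    [1.0, 0.0, 0.1], [0.8, 0.1, 0.1],  # Russian Blue
    [0.0, 1.0, 0.2], [0.1, 0.9, 0.1],  # Pug
])
labels = ["British Shorthair", "British Shorthair",
          "Russian Blue", "Russian Blue", "Pug", "Pug"]

for sim, a, b in closest_breed_pairs(features, labels, top_k=1):
    print(f"{a} <-> {b}: {sim:.3f}")  # British Shorthair <-> Russian Blue: 0.989
```

Comparing centroids ignores within-breed spread, so it complements the scatter plot rather than replacing it.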
Conclusion
In this report, I hope I've been able to show how you can plot image embeddings to better understand your dataset.
As an example, we looked at the pets dataset, trained a model using fastai, and then used W&B to create an embedding projector plot.
We saw examples of pet breeds that appear close to each other on the plot, such as "British Shorthair" and "Russian Blue".
I have also shared a colab notebook with working code so that you can replicate all the tables and figures in this report. Thanks for reading!