Visualizing Confusion Matrices With W&B

Using Keras with Weights & Biases, plot a confusion matrix at every step of model training and see where your algorithm is wrong.
Mathïs Fédérico

Introduction

For classification problems, we often rely on scalar metrics, but they don't capture the full picture. Accuracy, precision, recall, F1, and AUC each reduce performance to a single number, so none of them shows which classes the model mistakes for which.

The confusion matrix summarizes the performance of your classifier class by class, showing exactly which predictions go wrong and how.
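To make the limitation of scalar metrics concrete, here is a small made-up example (the labels, predictions, and helper names are hypothetical): two classifiers with identical accuracy whose mistakes are completely different. Only a per-class breakdown of errors tells them apart.

```python
from collections import Counter

# Made-up ground truth for a 3-class problem (classes 0, 1, 2).
y_true = [0, 0, 1, 1, 2, 2]
preds_a = [0, 0, 1, 1, 0, 0]  # every 2 is mistaken for a 0
preds_b = [1, 0, 2, 1, 2, 2]  # one 0 read as 1, one 1 read as 2


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the ground truth."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_pairs(y_true, y_pred):
    """Count (true, predicted) pairs among the misclassified examples only."""
    return Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)


for name, preds in [("a", preds_a), ("b", preds_b)]:
    # Both classifiers reach 4/6 accuracy, but their error pairs differ.
    print(name, accuracy(y_true, preds), dict(error_pairs(y_true, preds)))
```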

Experiment with this Google Colab Notebook

Visualizing the Confusion Matrix

Confusion matrix

The confusion matrix is a 2D histogram that compares predictions with ground truth. For example, if the model is given a picture of the digit 0 and predicts 2, the cell at row 2, column 0 gets +1. Hover over the cells in the panel above to see how many examples were misclassified in each bin.

Every time the model predicts the correct value, the +1 lands on the diagonal: row 0 column 0, row 1 column 1, and so on. A perfect model would place every example on the green diagonal, with no red squares for misclassified examples.
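The counting rule above can be sketched in a few lines of NumPy. This is a minimal illustration with made-up labels, using the same layout as the panel (row = predicted class, column = true class); the function name is hypothetical. Note that scikit-learn's confusion_matrix uses the transposed convention, with rows for true labels.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, n_classes):
    """2D histogram of predictions vs. ground truth.

    Layout: row = predicted class, column = true class.
    """
    m = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at adds 1 to m[pred, true] once per example, even when the
    # same (pred, true) pair occurs several times in the arrays.
    np.add.at(m, (np.asarray(y_pred), np.asarray(y_true)), 1)
    return m


# Made-up labels: the first example is a 0 that the model called a 2.
y_true = [0, 1, 2, 2, 0]
y_pred = [2, 1, 2, 0, 0]
m = confusion_matrix(y_true, y_pred, n_classes=3)
print(m)
# Correct predictions are exactly the diagonal entries.
print("correct:", np.trace(m), "of", m.sum())
```

The first example increments row 2, column 0, matching the "picture of a 0 predicted as 2" case described above.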

The model starts out mostly wrong and gradually improves. Click the gear icon in the upper left corner of the confusion matrix panel to get a step slider, then slide across training to watch the model learn to classify the handwritten digits more and more accurately.

Move the step slider all the way to 0. Almost all the predictions are class 3: there is a long red line for all the other digits that the model incorrectly guessed were 3s, and a single green square where it correctly labeled images of 3s.

Examine Misclassified Examples

Now that we've got a well-trained model, our confusion matrix looks pretty good, but there are still some data points that are misclassified.

To get a more in-depth view of misclassified images, explore the confusion examples. The darkest red off-diagonal squares in the confusion matrix mark the most common misclassifications.
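Finding those darkest squares can also be done programmatically. The sketch below is a hypothetical helper, not part of W&B: it assumes a square NumPy matrix with rows = predicted class and columns = true class, and the counts are invented for illustration.

```python
import numpy as np


def top_confusions(m, k=3):
    """Return the k largest off-diagonal cells as (predicted, true, count)."""
    off = m.copy()
    np.fill_diagonal(off, 0)  # ignore correct predictions
    # Flattened indices of the cells, from largest count to smallest.
    flat = np.argsort(off, axis=None)[::-1][:k]
    rows, cols = np.unravel_index(flat, off.shape)
    # Skip empty cells so a nearly perfect model returns fewer than k pairs.
    return [(int(r), int(c), int(off[r, c])) for r, c in zip(rows, cols) if off[r, c] > 0]


# Invented 3-class matrix: the most common error is true 2 predicted as 1.
m = np.array([
    [50, 1, 0],
    [3, 45, 7],
    [0, 2, 40],
])
print(top_confusions(m, k=3))
```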

In the panel below, each row shows one type of misclassification. On the left, marked in red, are images the model labeled incorrectly. On the right, in green, are the ground-truth references: an example of the class the image actually belongs to, and an example of the class the model thought it saw.
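One way to assemble such a row, sketched under assumptions: the helper name confusion_row is hypothetical, the labels are made up, and we assume each green reference is the first correctly classified example of its class. It returns dataset indices, which could then be used to look up the images themselves.

```python
import numpy as np


def confusion_row(y_true, y_pred, true_class, pred_class, max_examples=3):
    """Indices for one row: misclassified images plus two reference images."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    correct = y_true == y_pred
    # Red side: examples of true_class that the model labeled pred_class.
    wrong = np.flatnonzero((y_true == true_class) & (y_pred == pred_class))
    # Green side (assumption): one correctly classified example of each class.
    ref_true = np.flatnonzero(correct & (y_true == true_class))
    ref_pred = np.flatnonzero(correct & (y_true == pred_class))
    return {
        "misclassified": wrong[:max_examples].tolist(),
        "true_ref": int(ref_true[0]) if ref_true.size else None,
        "pred_ref": int(ref_pred[0]) if ref_pred.size else None,
    }


# Made-up labels: three 9s were predicted as 4s.
y_true = [9, 4, 9, 9, 4, 9]
y_pred = [4, 4, 9, 4, 4, 4]
print(confusion_row(y_true, y_pred, true_class=9, pred_class=4))
```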

Examine misclassified examples

Take a look at the third row. There are 3 examples of very loopy 9s that look a lot like 4s, even to the human eye.

It makes sense that the model got confused here!

Try it Yourself

Experiment with this Google Colab Notebook

1. Install the Libraries

2. Extend the WandbCallback

The built-in WandbCallback is already a good starting point, so we extend it like so (just copy and paste).

3. Put the Custom Callback in the fit Function

Add the validation data and labels (or a generator), and set input_type to image to get confusion examples.

Parameters:

About the Author

Mathïs Fédérico is a machine learning researcher based in France and the founder of Automatants, the AI group of the CentraleSupelec engineering school. He is currently pursuing a master's degree in computer science. He is the creator of LearnRL, a library that provides an accessible way to get started with reinforcement learning.
