For classification problems, we often rely on scalar metrics, but a single number rarely captures the full picture. Accuracy, precision, recall, F1, and AUC each lack something:

- Accuracy is misleading when classes are imbalanced
- Precision and recall work for binary classification, but require manually choosing a decision threshold
- F1 and AUC are popular in competitions, but can still be influenced by class skew
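To see why accuracy alone can mislead, consider a toy imbalanced dataset (the numbers here are made up for illustration): a model that always predicts the majority class scores high accuracy while being useless on the minority class.

```python
import numpy as np

# Hypothetical imbalanced problem: 95 negatives, 5 positives.
y_true = np.array([0] * 95 + [1] * 5)
y_pred = np.zeros(100, dtype=int)  # a "model" that always predicts the majority class

accuracy = (y_true == y_pred).mean()
print(accuracy)  # 0.95 -- looks great, yet the model never finds a positive

# Recall on the positive class exposes the failure:
tp = ((y_pred == 1) & (y_true == 1)).sum()
recall = tp / (y_true == 1).sum()
print(recall)  # 0.0
```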

The **confusion matrix** does a good job of summarizing the performance of your classifier.

The confusion matrix is a 2D histogram that compares predictions against ground truth. For example, if the model were given a picture of the number **0** and predicted 2, the cell at row 2, column 0 would get a +1. Hover over the cells in the panel above to see how many instances were misclassified in each bin.

Every time the model predicts the correct value, it adds a +1 on the diagonal: row 0 column 0, row 1 column 1, and so on. A perfect model would have all the examples lie on the green diagonal, with no red squares for misclassified examples.
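The counting rule above can be sketched in a few lines of NumPy, using the same convention as this article (rows are predictions, columns are ground-truth labels; note that `sklearn.metrics.confusion_matrix` uses the transposed convention, rows as ground truth):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    """2D histogram: cm[pred, true] counts how often class `pred`
    was predicted for examples whose ground truth is class `true`."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for true, pred in zip(y_true, y_pred):
        cm[pred, true] += 1
    return cm

y_true = [0, 0, 2, 1]
y_pred = [2, 0, 2, 1]  # one 0 was misread as a 2
cm = confusion_matrix(y_true, y_pred, n_classes=3)
print(cm)
# The misread 0 adds +1 at row 2, column 0;
# every correct prediction lands on the diagonal.
```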

The model starts out mostly wrong and gradually improves. Click the gear icon in the upper left corner of the confusion matrix panel to get a step slider, and slide across the training to see how the model learns to classify the handwritten digits more and more accurately.

Take the step slider all the way to 0. You can see that almost all the predictions were class 3. There's a long red row for all the other digits the model incorrectly guessed were 3s, and one green square where it correctly labeled images of 3s.

Now that we've got a well-trained model, our confusion matrix looks pretty good, but there are still some data points that are misclassified.

To get a more in-depth view of misclassified images, explore **confusion examples**. Looking at the darkest red off-diagonal squares in the confusion matrix, the most common misclassifications are:

- 3 instead of 5
- 9 instead of 4
- 4 instead of 9
- 0 instead of 6
- 0 instead of 8
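Picking out these "darkest red squares" programmatically amounts to sorting the off-diagonal cells of the confusion matrix. A minimal NumPy sketch, using a hypothetical 10x10 digit confusion matrix with a few planted errors:

```python
import numpy as np

def worst_confusions(cm, k=5):
    """Return the k largest off-diagonal cells as (pred, true, count) tuples."""
    off = cm.copy()
    np.fill_diagonal(off, 0)  # ignore correct predictions on the diagonal
    order = np.argsort(off, axis=None)[::-1][:k]
    out = []
    for idx in order:
        i, j = np.unravel_index(idx, off.shape)
        out.append((int(i), int(j), int(off[i, j])))
    return out

# Made-up counts, just to exercise the helper.
cm = np.diag([100] * 10)
cm[3, 5] = 12  # 5s misread as 3s
cm[9, 4] = 9   # 4s misread as 9s
cm[4, 9] = 7   # 9s misread as 4s
print(worst_confusions(cm, k=3))  # [(3, 5, 12), (9, 4, 9), (4, 9, 7)]
```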

Below, each misclassification is shown in a row. On the left, marked in red, are the images the model labeled incorrectly. On the right, in green, are the ground truths: an example of the class the model should have predicted, and an example of the class it thought it saw.

Take a look at the third row. There are three examples of very loopy 9s that look a lot like 4s, even to the human eye.

It makes sense that the model got confused here!


We already have a good callback example with the built-in `WandbCallback` object, so we extend it like so (just copy and paste).

Add the validation data and labels (or a generator), and set `input_type` to **image** to get confusion examples.

Parameters:

- `log_confusion_matrix` (bool): if true, log the confusion matrix at each epoch.
- `confusion_classes` (int): the number of worst confusion classes to show.
- `confusion_examples` (int): the number of confusion examples to show for each confusion class (0 disables them).

Mathïs Fédérico is a machine learning researcher based in France and the founder of Automatants, the AI group of the CentraleSupelec engineering school. He is currently doing a master's degree in computer science. He is also the creator of LearnRL, a library that provides an accessible way to get started with reinforcement learning.

**Profiles**