Confusion Matrix

Usage and examples for a multi-class confusion matrix. Made by Stacey Svetlichnaya using Weights & Biases

Method: wandb.plot.confusion_matrix()

Log a multi-class confusion matrix in one line:
wandb.log({"conf_mat" : wandb.plot.confusion_matrix(probs=None, y_true=ground_truth, preds=predictions, class_names=class_names)})
You can log this wherever your code has access to the model's predicted class ids (preds) or scores (probs), the ground truth labels (y_true), and the full list of class names.
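For concreteness, here is a minimal end-to-end sketch. The project name and the random dummy data are illustrative placeholders, not from the original experiment:

```python
import numpy as np
import wandb

# Hypothetical setup: 10 classes, 100 validation examples.
class_names = ["class_" + str(i) for i in range(10)]
ground_truth = np.random.randint(0, 10, size=100)  # integer class ids
predictions = np.random.randint(0, 10, size=100)   # predicted class ids

wandb.init(project="confusion-matrix-demo")  # hypothetical project name
wandb.log({"conf_mat": wandb.plot.confusion_matrix(
    probs=None,
    y_true=ground_truth,
    preds=predictions,
    class_names=class_names)})
wandb.finish()
```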

Try it yourself via Colab →

Basic usage

In this toy example, I fine-tune a CNN to predict one of 10 classes of living things in a photo (plants, animals, insects, etc.) while varying the number of epochs (E, or pretrain_epochs) and the number of training examples (NT, or num_train). I log a confusion matrix at each validation step, after every training epoch, so only the last confusion matrix from the end of training is visualized.
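A per-epoch logging hook might look like the following sketch, assuming a Keras-style model trained with model.fit and one-hot encoded validation labels; the ConfusionMatrixCallback class is an illustrative stand-in, not the original training code:

```python
import tensorflow as tf
import wandb

class ConfusionMatrixCallback(tf.keras.callbacks.Callback):
    """Log a confusion matrix to W&B at the end of every epoch."""

    def __init__(self, val_data, val_labels, class_names):
        super().__init__()
        self.val_data = val_data        # validation inputs
        self.val_labels = val_labels    # one-hot encoded ground truth
        self.class_names = class_names

    def on_epoch_end(self, epoch, logs=None):
        val_predictions = self.model.predict(self.val_data)
        top_pred_ids = val_predictions.argmax(axis=1)
        ground_truth_ids = self.val_labels.argmax(axis=1)
        # Logging under a fixed key each epoch means the chart shows
        # the matrix from the last logged step.
        wandb.log({"my_conf_mat_id": wandb.plot.confusion_matrix(
            preds=top_pred_ids,
            y_true=ground_truth_ids,
            class_names=self.class_names)})
```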

Powerful interactions

In this confusion matrix chart, you can hover over any cell to see the exact counts for each run, and toggle individual runs on and off to compare models directly.

Observations on this matrix

Runs colored closer to blue/violet above correspond to more training examples and more epochs, and these generally show stronger performance along the diagonal than runs colored closer to red. Mollusks classified as just "animals" and Amphibians classified as Reptiles are the two most common mistakes of the largest model (trained on 10,000 examples for 10 epochs). It's interesting to see the blue "NT 1000, E 10" model outperform the largest violet "NT 10000, E 10" model in several diagonal cells, even with a tenth of the data, perhaps due to overfitting in the larger model.

Logging details

In my validation step, I have access to val_data and the corresponding val_labels for all my validation examples, as well as the full list of possible labels for the model: all_labels = ["Amphibia", "Animalia", ... "Reptilia"] (so an integer class label of 0 corresponds to Amphibia, 1 to Animalia, ..., and 9 to Reptilia). Referencing the model I have trained so far, I call the following in my validation callback.
```python
val_predictions = model.predict(val_data)
top_pred_ids = val_predictions.argmax(axis=1)
ground_truth_ids = val_labels.argmax(axis=1)
wandb.log({"my_conf_mat_id": wandb.plot.confusion_matrix(
    preds=top_pred_ids,
    y_true=ground_truth_ids,
    class_names=all_labels)})
```
This creates a confusion matrix and logs it to the "Custom Charts" section of my Workspace, under the specified key my_conf_mat_id. Keep this key fixed to display multiple runs on the same confusion matrix.
Note: I explicitly take the argmax of the prediction scores to return the class ids of the top predictions (highest confidence score) across the images: one per image. While this is the most common scenario for a confusion matrix, the W&B implementation allows for other ways of computing the relevant prediction class id to log. For example, you could use an embedding or distance function to find the most likely class, or you could easily account for top-N accuracy across many classes. You could also log precomputed probabilities via the probs argument, making sure these have the shape (number of examples, number of classes).
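As a sketch of that last option, the same scores could be logged directly via probs, reusing val_predictions from above (before the argmax); the key my_conf_mat_probs is an illustrative placeholder:

```python
# Log raw prediction scores of shape (num_examples, num_classes);
# wandb takes the argmax over classes internally.
wandb.log({"my_conf_mat_probs": wandb.plot.confusion_matrix(
    probs=val_predictions,
    y_true=ground_truth_ids,
    class_names=all_labels)})
```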
See the API definition for wandb.plot.confusion_matrix() →

Customize as you wish

By editing the Vega spec, you can adjust many aspects of the chart using the Vega visualization grammar. For example, below I try two different color palettes for the model variants and focus the confusion matrix on different subsets of classes, all via the gear pop-up menu in the top right corner of the chart. You can try this yourself in the last chart of this report by opening that gear menu and editing the Vega spec directly.

Customization details

- Change the color scheme
- Zoom in on a subset of classes
- Save your changes

One last comparison

This toy model tends to over-predict plants and insects even when training on the full dataset. You can try your own experiments with a custom confusion matrix in this Colab. Please ask any questions & let me know how it goes in the comments below!