
Creating Custom Charts From Scratch With Weights & Biases

In this article, we explore how to build a multi-class confusion matrix in Weights & Biases using Vega, fine-tuning a CNN to predict one of 10 classes of living things.
Created on October 27 | Last edited on November 16

Log a Multi-Class Confusion Matrix to W&B

To create a multi-class confusion matrix in W&B, first find a place in your model development code where you have access to both the predicted labels and the corresponding ground truth for the same set of examples (typically a validation step). Then:
  • Pass these to a plot_confusion_matrix() Python function (currently provided as a standalone wrapper, soon to be added to the wandb API)
  • Create a custom chart with the Confusion Matrix v0 Vega spec (shown below, soon to be added as a preset under wandb.plot)
  • Connect the matching fields via the dropdown menus in the query editor, then view a chart like the one below and customize it further if you wish!
In this toy example, I fine-tune a CNN to predict one of 10 classes of living things (plants, animals, insects) while varying the number of epochs (E or pretrain_epochs) and the number of training examples (NT or num_train). In the run set below, you can toggle the eye symbol next to each run to show or hide it. You can see the relative performance of each model at a glance, and hover over individual bars to see the exact count.
Unsurprisingly, models with too few examples/epochs tend to make more mistakes (row of blue across "Aves" and "Reptilia" for the smallest model, row of red across "Animalia" for the second smallest). As the number of epochs and examples increases, the model tends to make more accurate predictions (stronger along the diagonal). Amphibians vs reptiles are some of the more commonly confused classes across these (admittedly noisy) models.

[Chart: multi-class confusion matrix for the run set "Vary num train and num epochs"]
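Reading "commonly confused classes" off a chart by eye works, but the same judgment can be made numerically by ranking the off-diagonal cells of a confusion matrix. The sketch below is a hypothetical helper (most_confused_pairs is not part of wandb or scikit-learn), and the 3-class counts are made up for illustration, not taken from the runs above. It assumes rows are actual classes and columns are predicted classes, as in the chart.

```python
def most_confused_pairs(cm, labels, top_k=3):
    """Return the top_k off-diagonal cells as (actual, predicted, count) tuples.

    cm[j][i] is the number of examples whose actual class is labels[j]
    and whose predicted class is labels[i] (rows = actual, columns = predicted).
    """
    cells = [
        (cm[j][i], labels[j], labels[i])
        for j in range(len(labels))
        for i in range(len(labels))
        if i != j and cm[j][i] > 0  # skip the diagonal (correct predictions) and empty cells
    ]
    # Largest counts first; the sort is stable, so ties keep row-major order
    cells.sort(key=lambda cell: cell[0], reverse=True)
    return [(actual, predicted, count) for count, actual, predicted in cells[:top_k]]


# Toy 3-class matrix with made-up counts
labels = ["Amphibia", "Aves", "Reptilia"]
cm = [
    [30, 1, 9],   # actual Amphibia
    [2, 45, 3],   # actual Aves
    [6, 0, 34],   # actual Reptilia
]
print(most_confused_pairs(cm, labels, top_k=2))
# → [('Amphibia', 'Reptilia', 9), ('Reptilia', 'Amphibia', 6)]
```

Listing both directions separately (Amphibia→Reptilia and Reptilia→Amphibia) keeps the asymmetry visible, which a single "confused pair" score would hide.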



Step 1: Add Python Code To Plot a wandb.Table()

In my validation step, I have access to val_data and the corresponding val_labels for all my validation examples, as well as the full list of possible labels for the model: labels=["Amphibia", "Animalia", ... "Reptilia"], so an integer class label of 0 means Amphibia, 1 means Animalia, and so on up to 9 = Reptilia. Using the model trained so far in my validation callback, I call:
val_predictions = model.predict(val_data).argmax(axis=1)  # class probabilities -> predicted class index
ground_truth = val_labels.argmax(axis=1)  # one-hot labels -> true class index
plot_confusion_matrix(ground_truth, val_predictions, labels)
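scikit-learn's confusion_matrix expects one class index per example, so both one-hot ground truth and per-class scores (assuming a softmax classifier, as a Keras-style model.predict typically returns) need an argmax over the class axis first. The arrays below are small made-up stand-ins for val_labels and the model output, not real data from this project.

```python
import numpy as np

# Made-up stand-ins: 3 validation examples, 4 classes
val_labels = np.array([[0, 0, 1, 0],    # one-hot: class 2
                       [1, 0, 0, 0],    # class 0
                       [0, 0, 0, 1]])   # class 3
softmax_out = np.array([[0.1, 0.1, 0.7, 0.1],   # per-class probabilities
                        [0.2, 0.5, 0.2, 0.1],
                        [0.0, 0.1, 0.1, 0.8]])

# argmax over axis=1 picks the highest-scoring class for each row (example)
ground_truth = val_labels.argmax(axis=1)       # → array([2, 0, 3])
val_predictions = softmax_out.argmax(axis=1)   # → array([2, 1, 3])
```

Here the second example is a mistake (actual class 0, predicted class 1), which would land in the off-diagonal cell at row 0, column 1.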
where plot_confusion_matrix() is defined below. You can further customize the call with:
  • true_labels or pred_labels to narrow down the subset of classes you'd like to display as rows or columns in the matrix
  • normalize flag to show normalized counts (floats up to a max of 1.0) instead of raw counts
import itertools

import numpy as np
import wandb
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels


def plot_confusion_matrix(y_true=None, y_pred=None, labels=None, true_labels=None,
                          pred_labels=None, normalize=False):
    """
    Computes the confusion matrix to evaluate the accuracy of a classification
    and logs it to W&B as a wandb.Table with one row per (predicted, actual) cell.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if labels is None:
        classes = unique_labels(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred, labels=classes)
    else:
        classes = np.asarray(labels)
        # y_true / y_pred are integer indices into labels (0 = labels[0], ...);
        # listing every index keeps the matrix full-size even if a class never occurs
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(classes)))
    if normalize:
        # Divide each row (actual class) by its total; empty rows become 0 instead of NaN
        with np.errstate(divide="ignore", invalid="ignore"):
            cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
        cm = np.around(cm, decimals=2)
        cm[np.isnan(cm)] = 0.0
    # Optionally keep only a subset of rows (actual classes)...
    if true_labels is None:
        true_classes = classes
    else:
        true_label_indexes = np.isin(classes, true_labels)
        true_classes = classes[true_label_indexes]
        cm = cm[true_label_indexes]
    # ...and of columns (predicted classes)
    if pred_labels is None:
        pred_classes = classes
    else:
        pred_label_indexes = np.isin(classes, pred_labels)
        pred_classes = classes[pred_label_indexes]
        cm = cm[:, pred_label_indexes]
    data = []
    # Rows of cm are actual classes (index j), columns are predicted classes (index i);
    # .item() converts NumPy scalars to plain Python values for the table
    for j, i in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        data.append([pred_classes[i].item(), true_classes[j].item(), cm[j, i].item()])
    wandb.log({"confusion_matrix": wandb.Table(
        columns=["Predicted", "Actual", "Count"],
        data=data)})
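The logged table is in "long" format: one [Predicted, Actual, Count] row per cell of the matrix, which is what lets the Vega chart facet by actual (rows) and predicted (columns) class. To see that layout, and what normalize does, without scikit-learn or a W&B run, here is a dependency-free sketch. confusion_rows is a hypothetical helper written for illustration, not part of wandb; it assumes integer class indices, as in this example.

```python
from collections import Counter


def confusion_rows(y_true, y_pred, labels, normalize=False):
    """Build [Predicted, Actual, Count] rows, one per confusion-matrix cell.

    y_true / y_pred are integer indices into labels. With normalize=True each
    count is divided by the number of examples of that actual class (row total).
    """
    counts = Counter(zip(y_true, y_pred))          # (actual, predicted) -> count
    rows = []
    for j, actual in enumerate(labels):            # matrix rows: actual class
        row_total = sum(counts[(j, i)] for i in range(len(labels)))
        for i, predicted in enumerate(labels):     # matrix columns: predicted class
            value = counts[(j, i)]
            if normalize:
                value = round(value / row_total, 2) if row_total else 0.0
            rows.append([predicted, actual, value])
    return rows


labels = ["Amphibia", "Reptilia"]
y_true = [0, 0, 0, 1]   # three amphibians, one reptile
y_pred = [0, 1, 1, 1]   # only one amphibian predicted correctly
for row in confusion_rows(y_true, y_pred, labels, normalize=True):
    print(row)
# ['Amphibia', 'Amphibia', 0.33]
# ['Reptilia', 'Amphibia', 0.67]
# ['Amphibia', 'Reptilia', 0.0]
# ['Reptilia', 'Reptilia', 1.0]
```

Normalizing by row means each actual class sums to roughly 1.0, so classes with few validation examples are not visually drowned out by larger ones.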


Step 2: Create a custom chart for the confusion matrix

W&B custom charts are written in Vega, a powerful and flexible visualization language (the spec below uses its higher-level Vega-Lite dialect). You can find many examples and walkthroughs online, and it can help to start from the existing preset that is most similar to your desired visualization. You can then iterate in small steps in our IDE, which re-renders the plot as you change its definition.
To add this multi-class confusion matrix:
  • From your project workspace or report, click "Add a visualization" and select "Custom chart"
  • Pick any existing preset and replace its definition with the full Vega spec below
  • Click "Save As" to give this preset a name for easier reference (I recommend "confusion_matrix")
{
  "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
  "description": "Multi-class confusion matrix",
  "data": {
    "name": "wandb"
  },
  "width": 40,
  "height": {"step": 6},
  "spacing": 5,
  "mark": "bar",
  "encoding": {
    "y": {"field": "name", "type": "nominal", "axis": {"labels": false},
          "title": null, "scale": {"zero": false}},
    "x": {
      "field": "${field:count}",
      "type": "quantitative",
      "axis": null,
      "title": null
    },
    "tooltip": [
      {"field": "${field:count}", "type": "quantitative", "title": "Count"},
      {"field": "name", "type": "nominal", "title": "Run name"}
    ],
    "color": {
      "field": "name",
      "type": "nominal",
      "legend": {"orient": "top", "titleOrient": "left"},
      "title": "Run name"
    },
    "row": {"field": "${field:actual}", "title": "Actual",
            "header": {"labelAlign": "left", "labelAngle": 0}},
    "column": {"field": "${field:predicted}", "title": "Predicted"}
  }
}
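The ${field:count}, ${field:actual} and ${field:predicted} entries in this spec are W&B placeholders, not real column names; Step 3 binds each one to a column of the logged wandb.Table. As a rough mental model only (fill_fields is a hypothetical helper, and this is not how W&B resolves queries internally), the binding behaves like a string substitution on a trimmed copy of the spec:

```python
import json
import re

# Trimmed copy of the spec above, keeping only the placeholder fields
SPEC_TEMPLATE = """
{
  "mark": "bar",
  "encoding": {
    "x": {"field": "${field:count}", "type": "quantitative"},
    "row": {"field": "${field:actual}", "title": "Actual"},
    "column": {"field": "${field:predicted}", "title": "Predicted"}
  }
}
"""


def fill_fields(template, field_map):
    """Replace each ${field:NAME} placeholder with the table column chosen for NAME.

    Hypothetical illustration; raises KeyError if a placeholder has no mapping.
    """
    return re.sub(r"\$\{field:(\w+)\}", lambda m: field_map[m.group(1)], template)


# Chart field -> wandb.Table column, mirroring the mapping chosen in the query editor
spec = json.loads(fill_fields(SPEC_TEMPLATE,
                              {"count": "Count", "actual": "Actual", "predicted": "Predicted"}))
print(spec["encoding"]["row"]["field"])   # → Actual
```

Keeping the column names out of the spec is what lets one saved preset be reused with any table whose columns you map to count, actual and predicted.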

Step 3: Map Relevant Data Fields From Logged Runs Into Your Chart

On the right-hand side of the visualization IDE, modify the run query to feed run data into the confusion matrix:
  • change summary to summaryTable
  • enter your custom table ID as the first entry in tableKeys: this is the key under which you logged the wandb.Table (in my case, confusion_matrix)
  • use the dropdown menu of the query editor to connect matching fields so that values logged to the "Count" column of the wandb.Table are read into the "count" field of the Vega chart, "Actual" column to the "actual" field, etc. The final query should look something like this:





Customize As You Wish

By editing the Vega spec, you can adjust the height, width, and color scheme of the chart. For example, this chart adds "scale": {"scheme": "rainbow"} to the "color" encoding to recolor the runs.
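That recoloring is a small edit to the color encoding. If you keep a copy of the spec as JSON (for example under version control), a minimal sketch of applying the change in Python looks like this; the abbreviated spec and the new width of 60 are arbitrary illustrative choices, not values from this report.

```python
import json

# Abbreviated confusion matrix spec: only the parts being edited
spec = json.loads("""
{
  "width": 40,
  "height": {"step": 6},
  "mark": "bar",
  "encoding": {
    "color": {"field": "name", "type": "nominal", "title": "Run name"}
  }
}
""")

# Recolor the runs with Vega's "rainbow" scheme and widen each facet cell
spec["encoding"]["color"]["scale"] = {"scheme": "rainbow"}
spec["width"] = 60

print(json.dumps(spec["encoding"]["color"]))
```

The edited JSON can then be pasted back into the custom chart editor in place of the original definition.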
Colors closer to blue here correspond to more training examples and more epochs, and these runs generally show stronger performance along the diagonal. Mollusks classified as just "animals" and amphibians classified as reptiles are the two most common mistakes of the largest model (trained on 10,000 examples for 10 epochs). It's interesting to see the green "NT 1000, E 10" model outperform the largest, blue "NT 10000, E 10" model in certain cells, despite training on one tenth as much data.


[Chart: recolored multi-class confusion matrix for the run set "Vary num train and num epochs"]
