Custom Charts from Scratch

Build a confusion matrix using Vega. Made by Stacey Svetlichnaya using Weights & Biases

Log a multi-class confusion matrix to W&B

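To create a multi-class confusion matrix in W&B, first find a place where your model development code has access to predicted labels and the corresponding ground truth for the same set of examples (typically in a validation step). Then follow the three steps below: log the data as a wandb.Table, create a custom chart for it, and map the logged fields into the chart.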

In this toy example, I finetune a CNN to predict one of 10 classes of living things (plants, animals, insects) while varying the number of epochs (E or pretrain_epochs) and the number of training examples (NT or num_train). In the run set below, you can toggle the eye symbol next to each run to show/hide it. You can see the relative performance of each model at a glance, and hover over the different bars to see the exact count.

Unsurprisingly, models with too few examples/epochs tend to make more mistakes (row of blue across "Aves" and "Reptilia" for the smallest model, row of red across "Animalia" for the second smallest). As the number of epochs and examples increases, the model tends to make more accurate predictions (stronger along the diagonal). Amphibians vs reptiles are some of the more commonly confused classes across these (admittedly noisy) models.
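"Stronger along the diagonal" can be made precise: correct predictions land on the diagonal, so accuracy is the diagonal total divided by all examples, and per-class recall is each diagonal cell divided by its row total. The dependency-free sketch below assumes rows are actual classes and columns are predicted classes (scikit-learn's convention); the helper names and the tiny 3-class matrix are hypothetical, not taken from the runs above.

```python
def accuracy(cm):
    """Fraction of all examples on the diagonal (correct predictions)."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[k][k] for k in range(len(cm)))
    return correct / total if total else 0.0


def per_class_recall(cm):
    """For each actual class (row), the share of its examples predicted correctly."""
    recalls = []
    for k, row in enumerate(cm):
        row_total = sum(row)
        recalls.append(cm[k][k] / row_total if row_total else 0.0)
    return recalls


# Hypothetical 3-class matrix: rows = actual, columns = predicted.
cm = [
    [8, 2, 0],
    [1, 9, 0],
    [0, 5, 5],
]
print(accuracy(cm))          # 22 correct out of 30
print(per_class_recall(cm))  # [0.8, 0.9, 0.5]
```

A model that is "stronger along the diagonal" than another raises these numbers; off-diagonal mass, like the rows of blue or red described above, lowers them.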


Step 1: Add Python code to plot a wandb.Table()

In my validation step, I have access to val_data and the corresponding val_labels for all my validation examples, as well as my full list of possible labels for the model, labels=["Amphibia", "Animalia", ... "Reptilia"], where the integer class label 0 corresponds to Amphibia, 1 to Animalia, ..., and 9 to Reptilia. Referencing the model I have trained so far in my validation callback, I call:

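# model.predict returns one score per class; argmax picks the predicted class id
val_predictions = model.predict(val_data).argmax(axis=1)
# val_labels are one-hot, so argmax recovers the true class id
ground_truth = val_labels.argmax(axis=1)
plot_confusion_matrix(ground_truth, val_predictions, labels)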

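where plot_confusion_matrix() is defined below. You can further customize the call with the optional true_labels and pred_labels arguments, which restrict the chart to a subset of actual or predicted classes, and with normalize=True, which divides each row by its total: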

import itertools

import numpy as np
import wandb
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels


def plot_confusion_matrix(y_true=None, y_pred=None, labels=None, true_labels=None,
                          pred_labels=None, normalize=False):
    """
    Computes the confusion matrix to evaluate the accuracy of a classification
    and logs it to W&B as a table with one (Predicted, Actual, Count) row per cell.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if labels is None:
        classes = unique_labels(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred)
    else:
        # assumes y_true/y_pred hold integer ids, so id k maps to labels[k];
        # passing every id keeps rows/columns aligned even if a class never occurs
        classes = np.asarray(labels)
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(labels)))

    if normalize:
        # divide each row (actual class) by its total; empty rows become 0
        with np.errstate(divide='ignore', invalid='ignore'):
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        cm = np.around(cm, decimals=2)
        cm[np.isnan(cm)] = 0.0

    # optionally keep only some actual classes (rows)
    if true_labels is None:
        true_classes = classes
    else:
        true_label_indexes = np.isin(classes, true_labels)
        true_classes = classes[true_label_indexes]
        cm = cm[true_label_indexes]

    # optionally keep only some predicted classes (columns)
    if pred_labels is None:
        pred_classes = classes
    else:
        pred_label_indexes = np.isin(classes, pred_labels)
        pred_classes = classes[pred_label_indexes]
        cm = cm[:, pred_label_indexes]

    # scikit-learn puts actual classes on rows (j) and predicted classes on columns (i)
    data = []
    for j, i in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        data.append([pred_classes[i], true_classes[j], cm[j, i]])

    wandb.log({"confusion_matrix": wandb.Table(
        columns=['Predicted', 'Actual', 'Count'],
        data=data)})
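The heart of plot_confusion_matrix() is a reshape: the 2-D matrix becomes one [Predicted, Actual, Count] row per cell, the long format that the chart in Step 2 facets over. Here is a dependency-free sketch of that reshape with the optional row normalization, assuming integer class ids that index into labels. The function name confusion_rows and the data are hypothetical, and nothing is logged to W&B.

```python
from itertools import product


def confusion_rows(y_true, y_pred, labels, normalize=False):
    """Return [predicted, actual, count] rows, one per (actual, predicted) cell.

    Assumes y_true and y_pred hold integer ids in range(len(labels)).
    """
    n = len(labels)
    # cm[a][p] counts examples whose actual id is a and predicted id is p
    cm = [[0] * n for _ in range(n)]
    for a, p in zip(y_true, y_pred):
        cm[a][p] += 1

    if normalize:
        # divide each row by its total; rows with no examples become all 0.0
        normalized = []
        for row in cm:
            total = sum(row)
            normalized.append([round(c / total, 2) if total else 0.0 for c in row])
        cm = normalized

    # actual classes vary slowest, matching the itertools.product loop above
    return [[labels[p], labels[a], cm[a][p]] for a, p in product(range(n), range(n))]


labels = ["Amphibia", "Reptilia"]
rows = confusion_rows([0, 0, 1, 1, 1], [0, 1, 1, 1, 0], labels)
for row in rows:
    print(row)
```

Passing rows like these as data to wandb.Table(columns=['Predicted', 'Actual', 'Count'], data=rows) gives the same table shape that plot_confusion_matrix() logs.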

Step 2: Create a custom chart for the confusion matrix

W&B custom charts are written in Vega, a powerful and flexible visualization language. You can find many examples and walkthroughs online, and it can help to start from the existing preset most similar to your desired visualization. You can then iterate in small steps in our IDE, which re-renders the plot as you change its definition. Here is the full spec for this multi-class confusion matrix; it uses Vega-Lite, the higher-level grammar built on top of Vega:

{
  "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
  "description": "Multi-class confusion matrix",
  "data": {
    "name": "wandb"
  },
  "width": 40,
  "height": {"step":6},
  "spacing": 5,
  "mark" : "bar",
  "encoding": {
    "y": {"field": "name", "type": "nominal", "axis" : {"labels" : false}, 
         "title" : null, "scale": {"zero": false}},
    "x": {
      "field": "${field:count}",
      "type": "quantitative",
      "axis" : null,
      "title" : null
    },
    "tooltip": [
      {"field": "${field:count}", "type": "quantitative", "title" : "Count"},
      {"field": "name", "type": "nominal", "title" : "Run name"}
    ],
    "color": {
      "field": "name",
      "type": "nominal",
      "legend": {"orient": "top", "titleOrient": "left"},
      "title": "Run name"
    },
    "row": {"field": "${field:actual}", "title": "Actual", 
            "header": {"labelAlign" : "left", "labelAngle": 0}},
    "column": {"field": "${field:predicted}", "title": "Predicted"}
  }
}
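In this spec, "row" and "column" facet the chart into one cell per (Actual, Predicted) pair, and within each cell the bars are keyed by the run "name", so every visible run contributes one bar whose length is its count. The plain-Python sketch below imitates that grouping, assuming the run query has already produced one record per run and table row. The helper facet_cells, the run names, and the counts are hypothetical.

```python
from collections import defaultdict


def facet_cells(records):
    """Group (run, actual, predicted, count) records into chart cells.

    Returns {(actual, predicted): [(run, count), ...]}, mirroring the spec's
    row (actual) x column (predicted) facets with one bar per run in each cell.
    """
    cells = defaultdict(list)
    for run, actual, predicted, count in records:
        cells[(actual, predicted)].append((run, count))
    return dict(cells)


records = [
    ("run-a", "Amphibia", "Amphibia", 40),
    ("run-b", "Amphibia", "Amphibia", 55),
    ("run-a", "Amphibia", "Reptilia", 12),
    ("run-b", "Amphibia", "Reptilia", 9),
]
for cell, bars in facet_cells(records).items():
    print(cell, bars)
```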

Step 3: Map relevant data fields from logged runs into your chart

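On the right-hand side of the visualization IDE, modify the run query to feed run data into the confusion matrix: map the chart's count, actual, and predicted fields (the ${field:count}, ${field:actual}, and ${field:predicted} placeholders in the spec) to the Count, Actual, and Predicted columns of the logged confusion_matrix table. The spec also uses each run's name to color its bars and label its tooltip.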
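Conceptually, the query step fills each ${field:...} placeholder in the spec with the name of a column from the query result. The sketch below imitates that substitution with a regular expression. It is only an illustration, assuming the mapped names are the Predicted, Actual, and Count columns logged in Step 1; it is not W&B's actual implementation, and fill_fields and field_mapping are hypothetical names.

```python
import json
import re

# Excerpt of the Step 2 spec, keeping only the fields that use placeholders.
spec_template = """{
  "mark": "bar",
  "encoding": {
    "x": {"field": "${field:count}", "type": "quantitative"},
    "row": {"field": "${field:actual}", "title": "Actual"},
    "column": {"field": "${field:predicted}", "title": "Predicted"}
  }
}"""

# Hypothetical mapping chosen in the query panel: chart field -> table column.
field_mapping = {"count": "Count", "actual": "Actual", "predicted": "Predicted"}


def fill_fields(template, mapping):
    """Replace every ${field:NAME} placeholder with mapping[NAME]."""
    return re.sub(r"\$\{field:(\w+)\}", lambda m: mapping[m.group(1)], template)


spec = json.loads(fill_fields(spec_template, field_mapping))
print(spec["encoding"]["row"]["field"])
```

Keeping the placeholders in the spec, rather than hard-coding column names, is what lets the same chart be reused with differently named tables.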


Customize as you wish

By editing the Vega spec, you can adjust the height, width, and color scheme of the chart. For example, this chart adds "scale": {"scheme": "rainbow"} to the color encoding to recolor the runs.
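If you keep the spec in Python, the same edit can be made programmatically. The sketch below, assuming a trimmed copy of the Step 2 spec as a dict, adds the color scale and changes the cell width and height. The helper restyle and the width/step values are placeholders, not the values used for the chart shown here.

```python
import json

# Trimmed copy of the Step 2 spec: only the keys this example touches.
spec = {
    "width": 40,
    "height": {"step": 6},
    "mark": "bar",
    "encoding": {
        "color": {"field": "name", "type": "nominal", "title": "Run name"},
    },
}


def restyle(spec, scheme, width, step):
    """Return a copy of spec with a new color scheme and cell size."""
    new_spec = json.loads(json.dumps(spec))  # deep copy via a JSON round-trip
    new_spec["encoding"]["color"]["scale"] = {"scheme": scheme}
    new_spec["width"] = width
    new_spec["height"] = {"step": step}
    return new_spec


styled = restyle(spec, "rainbow", width=60, step=8)
print(json.dumps(styled, indent=2))
```

The printed JSON can be pasted into the chart editor. Returning a copy leaves the original spec untouched, which makes it easy to compare several color schemes side by side.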

In the recolored chart, colors closer to blue correspond to more training examples/more epochs, and these runs generally show stronger performance along the diagonal. Mollusks classified as just "animals" and Amphibians classified as Reptiles are the two most common mistakes of the largest model (trained on 10,000 examples for 10 epochs). It's interesting to see the green "NT 1000, E 10" model outperform this largest blue "NT 10000, E 10" model in certain cells, even with one tenth of the training data.
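To list the cells where one run beats another, keep in mind that a higher count is better on the diagonal (correct predictions) and a lower count is better off it (mistakes). The sketch below assumes both runs were evaluated on the same validation examples, so raw counts are comparable. The helper cells_where_better and the counts are hypothetical.

```python
def cells_where_better(counts_a, counts_b):
    """List (actual, predicted) cells where run A beats run B.

    counts_* map (actual, predicted) -> count; missing cells count as 0.
    On the diagonal a higher count is better; off the diagonal, lower is better.
    """
    better = []
    for cell in sorted(set(counts_a) | set(counts_b)):
        a, b = counts_a.get(cell, 0), counts_b.get(cell, 0)
        actual, predicted = cell
        if (actual == predicted and a > b) or (actual != predicted and a < b):
            better.append(cell)
    return better


run_a = {("Amphibia", "Amphibia"): 48, ("Amphibia", "Reptilia"): 6}
run_b = {("Amphibia", "Amphibia"): 45, ("Amphibia", "Reptilia"): 9, ("Aves", "Aves"): 50}
print(cells_where_better(run_a, run_b))
```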
