Wave a wandb.plot() to Visualize

How to visualize classification models in a few lines with the W&B Python API. Made by Stacey Svetlichnaya using Weights & Biases

Easy visualizations for classification models

This report covers some of the built-in visualizations in wandb.plot() which you can call in one line on classification models.

Task: Identify 10 types of living things in photos with a simple convnet

The data consists of 10,000 training images and 2,000 validation images from the iNaturalist dataset, evenly distributed across 10 classes of living things like birds, insects, plants, and mammals (class names are given in Latin, so Aves, Insecta, Plantae, etc.). We fine-tune a convolutional neural network pretrained on ImageNet for this task: given a photo of a living thing, correctly classify it into one of the 10 classes. Note that in many experiments, I run toy versions of the training using a very small subset of the data (100-1000 training images, 500 validation) and a small number of epochs (3-5) for illustration purposes.
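For context, here is a minimal sketch of what the fine-tuning setup might look like. The report doesn't specify the exact architecture or framework, so the Keras InceptionV3 base below is an illustrative assumption:

import tensorflow as tf

NUM_CLASSES = 10  # e.g. Amphibia, Animalia, Arachnida, Aves, Fungi, Insecta, ...

# Start from ImageNet weights, dropping the original 1000-way classifier head
base = tf.keras.applications.InceptionV3(
    include_top=False, weights="imagenet", pooling="avg")
base.trainable = False  # freeze the pretrained features

# Add a new 10-way softmax head for the iNaturalist classes
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy",
              metrics=["accuracy"])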

Chart interactivity

In all of the preset custom charts, you can hover over the chart to see details for each point or curve, and you can show or hide individual runs with the "eye" icon next to each run name in the run table.

Precision-Recall and ROC curves

Log a precision-recall curve or an ROC curve with a single call:

wandb.plot.pr_curve(ground_truth, predictions, labels=None)
wandb.plot.roc_curve(ground_truth, predictions, labels=None)

You can log these whenever your training/testing code has access to the ground truth labels for a set of examples and the model's predicted scores (probabilities) for each class on those examples.
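For example, a minimal sketch at validation time (the variable names, tiny placeholder arrays, and project name are mine; in practice these come from your eval loop):

import numpy as np
import wandb

class_names = ["Amphibia", "Aves", "Fungi"]  # ...all 10 class names in practice

# Placeholder ground truth labels and per-class predicted probabilities
val_labels = np.array([0, 1, 2, 1])
val_probs = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1],
                      [0.2, 0.3, 0.5],
                      [0.3, 0.4, 0.3]])

wandb.init(project="inat-10")  # hypothetical project name
wandb.log({
    "pr_curve": wandb.plot.pr_curve(val_labels, val_probs, labels=class_names),
    "roc_curve": wandb.plot.roc_curve(val_labels, val_probs, labels=class_names),
})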

This lets you disentangle a classification model's performance on different classes. The toy models below perform best on birds (teal) and mushrooms (green): these curves have the largest area under the precision-recall curve and come closest to the top left corner of the ROC curve (highest true positive rate at the lowest false positive rate). The models seem to struggle most with reptiles (gray) and amphibians (blue), which show low area under the precision-recall curve and a higher false positive rate for the same true positive rate. Note that the precision-recall curve is heavily subsampled (20 points per curve).


Custom line plot: Average precision across runs

Log an average precision curve to compare different versions of your model, using the same inputs as above (ground truth labels and predicted scores). For now this requires two lines of Python in addition to wandb.plot.line(), as in the sketch below (details in the Custom Line Plot report).
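Here is one way it might look, computing micro-averaged precision and recall with scikit-learn; y_onehot and y_probs are placeholder names for one-hot true labels and predicted scores:

import numpy as np
import wandb
from sklearn.metrics import precision_recall_curve

# Tiny placeholder data; in practice these come from your validation set
y_onehot = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
y_probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.4, 0.6]])

# The "two extra lines": micro-average by flattening all classes into one
# binary problem, then compute the precision-recall pairs
precision, recall, _ = precision_recall_curve(y_onehot.ravel(), y_probs.ravel())

wandb.init(project="inat-10")  # hypothetical project name
data = [[r, p] for (r, p) in zip(recall, precision)]
table = wandb.Table(data=data, columns=["recall_micro", "precision_micro"])
wandb.log({"avg_pr": wandb.plot.line(table, "recall_micro", "precision_micro",
                                     title="Average Precision")})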

Instead of separating performance by class, this approach groups all classes together and focuses on overall binary correctness (micro-averaged precision). Below, average precision predictably increases as we increase the number of training samples. However, at 100 samples, training is very noisy: multiple runs of the same code (teal vs green curves) diverge substantially. And at 1000 vs 2000 samples (violet vs pink curves), performance is very similar.


Custom scatter plot: Explore correlations

In the most general case, you can plot an arbitrary metric Y versus an arbitrary metric X. Assume these metrics are a series of numbers that have some correspondence and can be expressed as a list of (x, y) points: (x0, y0), (x1, y1), (x2, y2)... Once you have the metrics available as two lists x_values and y_values, you can plot y vs x by calling the code below (details in Custom Scatter Plots).

data = [[x, y] for (x, y) in zip(x_values, y_values)]
table = wandb.Table(data=data, columns=["x", "y"])
wandb.log({"my_custom_plot_id": wandb.plot.scatter(table, "x", "y", title="Custom Y vs X Scatter Plot")})

Here I plot the confidence scores for two different labels on the same set of validation samples. You can see that the model is more likely to confuse reptiles and amphibians at low confidence (yellow points, many on a diagonal), slightly less likely to confuse insects and arachnids (blue points, fewer in the middle of the plot, more concentrated in bottom left), and unlikely to confuse birds for mushrooms (magenta points, overall low confidence for label=bird and a range of confidence for label=mushroom). These are of course generalizations from a single visualization, but they can help build intuition and reveal where the model is likely to confuse classes.
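A sketch of how such a scatter might be built from per-class softmax scores; the class indices, column names, and random placeholder scores below are mine:

import numpy as np
import wandb

rng = np.random.default_rng(0)
val_probs = rng.dirichlet(np.ones(10), size=500)  # placeholder softmax outputs

REPTILIA, AMPHIBIA = 3, 4  # hypothetical indices of the two classes to compare

data = [[r, a] for (r, a) in zip(val_probs[:, REPTILIA], val_probs[:, AMPHIBIA])]
table = wandb.Table(data=data, columns=["reptile_conf", "amphibian_conf"])

wandb.init(project="inat-10")  # hypothetical project name
wandb.log({"reptile_vs_amphibian": wandb.plot.scatter(
    table, "reptile_conf", "amphibian_conf",
    title="Amphibian vs Reptile Confidence")})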


Class histograms: Confidence scores

Log a custom histogram (sorting a list of values into bins by count/frequency of occurrence) natively in a few lines with Custom Histograms. Let's say I have a list of prediction confidence scores (scores) and want to visualize their distribution:

data = [[s] for s in scores]
table = wandb.Table(data=data, columns=["scores"])
wandb.log({'my_histogram': wandb.plot.histogram(table, "scores")})

I vary NT, the number of training examples, and E, the number of training epochs, for each run. Both numbers are tiny for illustration purposes. When they are too small, a toy model only outputs very low confidence scores (<0.1). With increasing epochs and numbers of training examples, we start to see more high-confidence scores, along with some intermediate ones, for the model's prediction (in this case, identifying the living thing in an image as a bird as opposed to one of the other nine classes: plant, mammal, etc.).


Custom bar charts: Per-class metrics like precision

Log a custom bar chart (a list of labeled values rendered as bars) natively in a few lines:

data = [[label, val] for (label, val) in zip(labels, values)]
table = wandb.Table(data=data, columns=["label", "value"])
wandb.log({"my_bar_chart_id": wandb.plot.bar(table, "label", "value", title="Custom Bar Chart")})

You can use this to log arbitrary bar charts. Note that the number of labels and values in the lists must match exactly (i.e. each data point must have both). In the example below, you can display/hide individual runs from the bar chart by clicking on the "eye" icon to the left of each run name under "Toy CNN variants".

I vary NT, the number of training examples, and E, the number of training epochs, for each run. Both numbers are tiny for illustration purposes. With increasing epochs and numbers of training examples, precision increases for all classes. Anecdotally, birds (Aves) are the least visually diverse class in the dataset and animals (Animalia) are probably the most visually diverse, which matches the fast improvement of the former and the lagging performance of the latter.
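For instance, per-class precision could be computed with scikit-learn and logged as a bar chart like so (the class names and tiny placeholder labels are mine):

import wandb
from sklearn.metrics import precision_score

class_names = ["Amphibia", "Animalia", "Aves"]  # ...all 10 class names in practice

# Placeholder true and predicted class indices from a validation set
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 1]

# average=None returns one precision value per class
per_class_precision = precision_score(y_true, y_pred, average=None)

wandb.init(project="inat-10")  # hypothetical project name
data = [[label, val] for (label, val) in zip(class_names, per_class_precision)]
table = wandb.Table(data=data, columns=["class", "precision"])
wandb.log({"per_class_precision": wandb.plot.bar(table, "class", "precision",
                                                 title="Precision per Class")})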


Create your own!

You can create your own preset and log data to it directly from a script, as in the sketch below. Click on the gear icon in the top right of the chart to see a slider that adjusts the bin size, giving a slightly different perspective on the data.
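A minimal sketch of logging to a saved preset with wandb.plot_table(); the preset ID ("my-entity/my-histogram") and its field name are hypothetical and must match a preset you've defined in the W&B UI:

import wandb

scores = [0.1, 0.4, 0.35, 0.8]  # placeholder confidence scores
table = wandb.Table(data=[[s] for s in scores], columns=["score"])

wandb.init(project="inat-10")  # hypothetical project name
wandb.log({"my_preset_chart": wandb.plot_table(
    vega_spec_name="my-entity/my-histogram",  # hypothetical preset ID
    data_table=table,
    fields={"value": "score"},  # map the preset's field to a table column
)})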

This chart shows confidence scores of an image classification CNN for two different labels on the same images, specifically for identifying the creature in each image as an amphibian versus a reptile. This toy model is a bit more certain about amphibians (narrower, higher peak) than reptiles (broader range of scores).
