

How to visualize classification models in a few lines with the W&B Python API

Easy one-line visualizations for classification models

This report covers some of the built-in visualizations in wandb.plot, each of which you can call in one line on a classification model.

Task: Identify 10 types of living things in photos with a simple convnet

The data consists of 10,000 training images and 2,000 validation images from the iNaturalist dataset, evenly distributed across 10 classes of living things like birds, insects, plants, and mammals (names given in Latin, so Aves, Insecta, Plantae, etc. :). We will fine-tune a convolutional neural network already trained on ImageNet on this task: given a photo of a living thing, correctly classify it into one of the 10 classes. Note that in many experiments, I run a toy version of the training using a very small subset of the data (100-1000 train, 500 val) and a small number of epochs (3-5) for illustration purposes.
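
A minimal sketch of the fine-tuning setup: the framework (Keras) and the specific base network are assumptions for illustration, not the report's actual code.

import tensorflow as tf

NUM_CLASSES = 10

# Load an ImageNet-pretrained backbone without its classification head.
base = tf.keras.applications.ResNet50(weights="imagenet", include_top=False,
                                      pooling="avg", input_shape=(224, 224, 3))
base.trainable = False  # start by training only the new head

# Add a 10-way softmax head for the living-thing classes.
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])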

Precision-Recall and ROC Curves

Log a Precision-Recall curve or an ROC curve (see links for more details):

wandb.plot.pr_curve(ground_truth, predictions, labels=None)
wandb.plot.roc_curve(ground_truth, predictions, labels=None)

You can log this whenever your training/testing code has access to:

  • a model's predicted scores (predictions) on a set of examples
  • the corresponding ground truth labels (ground_truth) for those examples
  • (optionally) a list of the class label names (labels=["cat", "dog", "bird"...] if 0 = cat, 1 = dog, 2 = bird, etc.)
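
To make this concrete, here is a minimal end-to-end sketch of logging both curves from a validation pass. The project name, class names, and random stand-in data are illustrative assumptions; in practice the labels and scores come from your model and validation set.

import numpy as np
import wandb

# Assumed names for the 10 iNaturalist classes (illustrative ordering).
class_names = ["Amphibia", "Animalia", "Arachnida", "Aves", "Fungi",
               "Insecta", "Mammalia", "Mollusca", "Plantae", "Reptilia"]

run = wandb.init(project="inaturalist-classifier")

# Stand-ins: ground-truth class ids and per-class probability scores for 500 samples.
ground_truth = np.random.randint(0, len(class_names), size=500)
predictions = np.random.dirichlet(np.ones(len(class_names)), size=500)

wandb.log({
    "pr_curve": wandb.plot.pr_curve(ground_truth, predictions, labels=class_names),
    "roc_curve": wandb.plot.roc_curve(ground_truth, predictions, labels=class_names),
})
run.finish()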

This lets you disentangle a classification model's performance on different classes. The toy models below perform best on birds (teal) and mushrooms (green): these curves have the largest area under the precision-recall curve and are closest to the top left corner of the ROC plot (highest true positive rate at the lowest false positive rate). The models seem to struggle the most with reptiles (gray) and amphibians (blue): lower area under the precision-recall curve and a higher false positive rate at the same true positive rate. Note that the precision-recall curve subsamples heavily (20 points per curve).




Run set (11 runs)


Custom Line Plot: Average Precision Across Runs

Log an average precision curve to compare different versions of your model, using the same arguments as above. For now this requires two lines of Python in addition to wandb.plot.line() (details in the Custom Line Plot report).
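
As a rough sketch of those two extra lines (the micro-averaged precision-recall computation with scikit-learn and the random stand-in data are assumptions for illustration):

import numpy as np
import wandb
from sklearn.metrics import precision_recall_curve

run = wandb.init(project="inaturalist-classifier")

# Stand-ins for real validation labels and per-class scores.
n_classes = 10
ground_truth = np.random.randint(0, n_classes, size=500)
predictions = np.random.dirichlet(np.ones(n_classes), size=500)

# Micro-average: flatten one-hot ground truth and scores into one binary problem.
one_hot = np.eye(n_classes)[ground_truth]
precision, recall, _ = precision_recall_curve(one_hot.ravel(), predictions.ravel())

# The two extra lines: build a table of (recall, precision) points, then plot the line.
table = wandb.Table(data=[[r, p] for r, p in zip(recall, precision)],
                    columns=["recall", "precision"])
wandb.log({"avg_precision": wandb.plot.line(table, "recall", "precision",
                                            title="Average Precision")})
run.finish()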

Instead of separating performance by class, this approach groups the classes together and focuses on overall binary correctness (micro-averaged precision). Below, average precision predictably increases as we increase the number of training samples. However, at 100 samples the training is very noisy across multiple runs of the same code (teal vs. green curves), and the performance at 1000 vs. 2000 samples (violet vs. pink curves) is very similar.




Avg Precision Runs (5 runs)


Custom Scatter Plot: Score Correlations

In the most general case, you can plot an arbitrary metric Y versus an arbitrary metric X. Assume these metrics are a series of numbers that have some correspondence and can be expressed as a list of (x, y) points: (x0, y0), (x1, y1), (x2, y2)... Once you have the metrics available as two lists x_values and y_values, you can plot y vs x by calling the code below (details in Custom Scatter Plots).

data = [[x, y] for (x, y) in zip(x_values, y_values)]
table = wandb.Table(data=data, columns=["x", "y"])
wandb.log({"my_custom_plot_id": wandb.plot.scatter(table, "x", "y", title="Custom Y vs X Scatter Plot")})

Here I plot the confidence scores for two different labels on the same set of validation samples. You can see that the model is more likely to confuse reptiles and amphibians at low confidence (yellow points, many on a diagonal), slightly less likely to confuse insects and arachnids (blue points, fewer in the middle of the plot, more concentrated in the bottom left), and unlikely to confuse birds for mushrooms (magenta points, overall low confidence for label=bird and a range of confidence for label=mushroom). These are of course generalizations from a single visualization, but they can help build intuition and reveal where the model is likely to confuse classes.
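
A minimal sketch of building such a scatter from per-class confidence scores (the score matrix and class names are illustrative stand-ins, shown here for the reptile vs. amphibian comparison):

import numpy as np
import wandb

class_names = ["Amphibia", "Animalia", "Arachnida", "Aves", "Fungi",
               "Insecta", "Mammalia", "Mollusca", "Plantae", "Reptilia"]

run = wandb.init(project="inaturalist-classifier")

# Stand-in per-class confidence scores for 500 validation samples.
predictions = np.random.dirichlet(np.ones(len(class_names)), size=500)

# x: confidence that each sample is a reptile; y: confidence that it is an amphibian.
x_values = predictions[:, class_names.index("Reptilia")]
y_values = predictions[:, class_names.index("Amphibia")]

data = [[x, y] for (x, y) in zip(x_values, y_values)]
table = wandb.Table(data=data, columns=["Reptilia score", "Amphibia score"])
wandb.log({"reptile_vs_amphibian": wandb.plot.scatter(
    table, "Reptilia score", "Amphibia score",
    title="Amphibia vs Reptilia confidence")})
run.finish()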




Custom Y vs X class confidence scores (3 runs)


Class distributions and composite histograms

Using the same predictions and class names as above, we can compare confidence scores across two classes more explicitly by calling

wandb.plots.class_distributions(predictions, class_names, ["Animalia", "Plantae"])

where the last argument is the two class labels to compare in (red bins, blue bins) order.

You can hover over the plot and scroll down to see a slider that adjusts the bin width for the overlaid histograms. This automatically visualizes the distribution of scores for a particular class. In the bottom left, a model variant is confident that most samples are NOT animals (the vast majority of prediction scores are below 0.2) and uncertain about samples being plants (scores distributed fairly evenly across the range of possibilities). In the bottom right, the confidence scores for birds and mushrooms are fairly similar, with slightly higher confidence for samples NOT being mushrooms.

This method also has a more general version that takes two arrays of values:

wandb.plots.multi_histogram(red_bin_values, blue_bin_values)
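
A minimal sketch of assembling the two score arrays for that call (the score matrix and class names are illustrative stand-ins; the call itself follows the form shown above):

import numpy as np
import wandb

class_names = ["Amphibia", "Animalia", "Arachnida", "Aves", "Fungi",
               "Insecta", "Mammalia", "Mollusca", "Plantae", "Reptilia"]

run = wandb.init(project="inaturalist-classifier")

# Stand-in per-class confidence scores for 500 validation samples.
predictions = np.random.dirichlet(np.ones(len(class_names)), size=500)

red_bin_values = predictions[:, class_names.index("Aves")]    # bird scores (red bins)
blue_bin_values = predictions[:, class_names.index("Fungi")]  # mushroom scores (blue bins)

wandb.plots.multi_histogram(red_bin_values, blue_bin_values)
run.finish()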


