Wave a wandb.plot() to Visualize
Easy visualizations for classification models
This report covers some of the built-in visualizations in wandb.plot() which you can call in one line on classification models.
Task: Identify 10 types of living things in photos with a simple convnet
The data consists of 10,000 training images and 2,000 validation images from the iNaturalist dataset, evenly distributed across 10 classes of living things like birds, insects, plants, and mammals (names given in Latin, so Aves, Insecta, Plantae, etc. :). We will fine-tune a convolutional neural network already trained on ImageNet on this task: given a photo of a living thing, correctly classify it into one of the 10 classes. Note that in many experiments, I run toy versions of the training using a very small subset of the data (100-1000 train, 500 val) and a small number of epochs (3-5) for illustration purposes.
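As a hedged sketch of this setup (the framework and architecture here are my assumptions for illustration; the report does not pin them down), the fine-tuning skeleton might look like:

import torch.nn as nn
from torchvision import models

# Start from a convnet pretrained on ImageNet and swap in a fresh
# 10-way classification head for the 10 living-thing classes.
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 10)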
Chart interactivity
In all of the preset custom charts, you can:
- zoom and pan inside the chart to focus on the area of interest
- hover your cursor over lines/points to see detailed information about them
- expand the Run Set (gray tab below the charts) and toggle individual runs on/off via the "eye" icon to the left of each run name
Precision-Recall and ROC curves
Log a Precision-Recall curve or an ROC curve (see links for more details):
wandb.plot.pr_curve(ground_truth, predictions, labels=None)
wandb.plot.roc_curve(ground_truth, predictions, labels=None)
You can log this whenever your training/testing code has access to:
- a model's predicted scores (predictions) on a set of examples
- the corresponding ground truth labels (ground_truth) for those examples
- (optionally) a list of the labels/class names (labels=["cat", "dog", "bird"...] if 0 = cat, 1 = dog, 2 = bird, etc.)
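For example, a minimal sketch, assuming wandb is imported and a run has been initialized with wandb.init(), where val_labels holds the ground-truth class indices and val_probs holds the model's per-class scores (shape: n_examples x n_classes); the class names are my guess at the dataset's Latin labels:

# Log both curves for the same set of validation predictions.
class_names = ["Amphibia", "Animalia", "Arachnida", "Aves", "Fungi",
               "Insecta", "Mammalia", "Mollusca", "Plantae", "Reptilia"]
wandb.log({"pr_curve": wandb.plot.pr_curve(val_labels, val_probs, labels=class_names)})
wandb.log({"roc_curve": wandb.plot.roc_curve(val_labels, val_probs, labels=class_names)})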
This lets you disentangle a classification model's performance on different classes. The toy models below perform best on birds (teal) and mushrooms (green): these curves have the largest area under the precision-recall curve and are closest to the top left corner of the ROC curve (highest true positive rate and lowest false positive rate). The models seem to struggle the most with reptiles (gray) and amphibians (blue): low area under the precision-recall curve, and a higher false positive rate for the same true positive rate. Note that the precision-recall curve subsamples heavily (20 points per curve).
Custom line plot: Average precision across runs
Log an average precision curve to compare across different versions of your model, with the same arguments as above. For now this requires two lines of Python in addition to wandb.plot.line() (details in the Custom Line Plot report).
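A minimal sketch of those two lines plus the wandb.plot.line() call (recall_vals and precision_vals are hypothetical micro-averaged values computed elsewhere, e.g. with sklearn's precision_recall_curve):

# Pair up the (recall, precision) points and log them as a line plot.
data = [[x, y] for (x, y) in zip(recall_vals, precision_vals)]
table = wandb.Table(data=data, columns=["recall", "precision"])
wandb.log({"avg_precision": wandb.plot.line(table, "recall", "precision", title="Average Precision")})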
Instead of separating performance by class, this approach groups all classes together and focuses on overall binary correctness (micro-averaged precision). Below, average precision predictably increases as we increase the number of training samples. However, at 100 samples, training is very noisy across multiple runs of the same code (teal vs green curve), and even at 1000 vs 2000 samples (violet vs pink curve), the performance is very similar.
Custom scatter plot: Explore correlations
In the most general case, you can plot any arbitrary metric Y versus an arbitrary metric X. Assume these metrics are a series of numbers that have some correspondence and can be expressed as a list of (x, y) points: (x0, y0), (x1, y1), (x2, y2)... Once you have the metrics available as two lists x_values and y_values, you can plot y vs x by calling the code below (details in Custom Scatter Plots).
data = [[x, y] for (x, y) in zip(x_values, y_values)]
table = wandb.Table(data=data, columns=["x", "y"])
wandb.log({"my_custom_plot_id": wandb.plot.scatter(table, "x", "y", title="Custom Y vs X Scatter Plot")})
Here I plot the confidence scores for two different labels on the same set of validation samples. You can see that the model is more likely to confuse reptiles and amphibians at low confidence (yellow points, many on a diagonal), slightly less likely to confuse insects and arachnids (blue points, fewer in the middle of the plot, more concentrated in bottom left), and unlikely to confuse birds for mushrooms (magenta points, overall low confidence for label=bird and a range of confidence for label=mushroom). These are of course generalizations from a single visualization, but they can help build intuition and reveal where the model is likely to confuse classes.
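For reference, a hedged sketch of how such a chart could be logged (val_probs and the class indices REPTILIA and AMPHIBIA are illustrative names, not code from the original report):

# Confidence for two classes on the same validation images, one point per image.
data = [[p[REPTILIA], p[AMPHIBIA]] for p in val_probs]
table = wandb.Table(data=data, columns=["reptile confidence", "amphibian confidence"])
wandb.log({"reptile_vs_amphibian": wandb.plot.scatter(table, "reptile confidence", "amphibian confidence", title="Reptile vs Amphibian Confidence")})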
Class histograms: Confidence scores
Log a custom histogram (sorting a list of values into bins by count/frequency of occurrence) natively in a few lines with Custom Histograms. Let's say I have a list of prediction confidence scores (scores) and want to visualize their distribution:
data = [[s] for s in scores]
table = wandb.Table(data=data, columns=["scores"])
wandb.log({'my_histogram': wandb.plot.histogram(table, "scores")})
I vary NT, the number of training examples, and E, the number of training epochs, for each run. Both numbers are tiny for illustration purposes. When these are too small, a toy model only outputs very low confidence scores (<0.1). With increasing epochs and numbers of training examples, we start to see more high confidence scores and some intermediate scores for the model's prediction confidence (in this case, identifying the living thing in an image as a bird as opposed to one of the other nine classes: plant, mammal, etc).
Custom bar charts: Per-class metrics like precision
Log a custom bar chart—a list of labeled values as bars—natively in a few lines:
data = [[label, val] for (label, val) in zip(labels, values)]
table = wandb.Table(data=data, columns=["label", "value"])
wandb.log({"my_bar_chart_id": wandb.plot.bar(table, "label", "value", title="Custom Bar Chart")})
You can use this to log arbitrary bar charts. Note that the number of labels and values in the lists must match exactly (i.e. each data point must have both). In the example below, you can display/hide individual runs from the bar chart by clicking on the "eye" icon to the left of each run name under "Toy CNN variants".
I vary NT, the number of training examples, and E, the number of training epochs, for each run. Both numbers are tiny for illustration purposes. With increasing epochs and numbers of training examples, the precision increases for all classes. Anecdotally, birds (Aves) are the least visually diverse class and animals (Animalia) are probably the most visually diverse class in the dataset, which matches the fast improvement of the former and the lagging performance of the latter.
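As an illustrative sketch of logging per-class precision (assuming scikit-learn, plus hypothetical y_true, y_pred, and class_names from the validation set):

from sklearn.metrics import precision_score

# One precision value per class, displayed as one bar per class.
per_class = precision_score(y_true, y_pred, average=None)
data = [[name, p] for (name, p) in zip(class_names, per_class)]
table = wandb.Table(data=data, columns=["class", "precision"])
wandb.log({"per_class_precision": wandb.plot.bar(table, "class", "precision", title="Precision by Class")})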
Create your own!
You can create your own preset and log data to it directly from a script. Click on the gear icon in the top right to see a slider that adjusts the bin size, giving a slightly different perspective on the data.
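One way to log to such a preset from a script is wandb.plot_table() (a hedged sketch: "my-team/my-preset" is a placeholder for a Vega spec you have saved in the UI, and the fields dict must map that spec's field names to your table's columns):

# Send a table of confidence scores to a saved custom chart preset.
table = wandb.Table(data=[[s] for s in scores], columns=["scores"])
wandb.log({"my_custom_chart": wandb.plot_table("my-team/my-preset", table, {"value": "scores"})})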
This chart shows confidence scores of an image classification CNN for two different labels on the same images, specifically for identifying the creature in each image as an amphibian versus a reptile. This toy model is a bit more certain about amphibians (narrower, higher peak) than reptiles (broader range of scores).