Plot Precision Recall Curves

Usage and examples for wandb.plot.pr_curve(). Made by Stacey Svetlichnaya using Weights & Biases

Method: wandb.plot.pr_curve()

Log a Precision-Recall curve in one line:
wandb.log({"pr" : wandb.plot.pr_curve(ground_truth, predictions, labels=None, classes_to_plot=None)})
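You can log this whenever your code has access to the arguments shown above: the ground-truth labels (ground_truth) and the model's predicted scores (predictions) for the same set of examples, plus, optionally, a list of class names (labels) and a subset of those classes to draw (classes_to_plot).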

Try it yourself via Colab →
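As a rough sketch of what those inputs can look like, the snippet below builds toy data and wraps the logging call in a helper. It assumes ground_truth is a list of integer class indices and predictions holds one row of per-class scores for each example; the class names, the project name "pr-curve-demo", and both helper functions are hypothetical and exist only for illustration.

```python
import random

CLASS_NAMES = ["cat", "dog", "bird"]  # hypothetical class names


def make_toy_inputs(num_examples=8, num_classes=3, seed=0):
    """Build toy labels and per-class scores (assumed input layout)."""
    rng = random.Random(seed)
    # One integer class index per example.
    ground_truth = [rng.randrange(num_classes) for _ in range(num_examples)]
    predictions = []
    for _ in range(num_examples):
        raw = [rng.random() + 1e-6 for _ in range(num_classes)]
        total = sum(raw)
        # Normalize so each row of scores sums to 1, like softmax output.
        predictions.append([r / total for r in raw])
    return ground_truth, predictions


def log_pr_curve(ground_truth, predictions, labels):
    """Log the curve; needs wandb (and scikit-learn) installed."""
    import wandb  # imported here so the data helper runs without wandb

    # Offline mode keeps the sketch from needing a login.
    run = wandb.init(project="pr-curve-demo", mode="offline")
    run.log({"pr": wandb.plot.pr_curve(ground_truth, predictions, labels=labels)})
    run.finish()


ground_truth, predictions = make_toy_inputs()
# Uncomment to log the toy data as a PR curve:
# log_pr_curve(ground_truth, predictions, CLASS_NAMES)
```

In a real validation step, ground_truth and predictions would come from the validation set and the model's output scores rather than from random numbers.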

Basic usage

I fine-tune a CNN to predict 10 classes of living things: plants, birds, insects, etc. In my validation step, I call
wandb.log({"my_custom_plot_id" : wandb.plot.pr_curve(ground_truth, predictions, labels=["Amphibia", "Animalia",..."Reptilia"])})
to produce the following curve for each run of my model (where each run logs to the same plot key, my_custom_plot_id). Scroll over the chart area to zoom in, click+drag to pan, and hover to see more detail about a line.
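To make concrete what each line on such a chart represents, here is a minimal pure-Python sketch of how precision and recall can be computed for a single class, treating that class as positive and every other class as negative. This is an illustration of the general technique, not wandb's own implementation; the function name and the toy data are hypothetical.

```python
def precision_recall_points(y_true, scores, positive_class):
    """Return (recall, precision) pairs for one class, one-vs-rest."""
    # 1 if the example belongs to the class of interest, else 0.
    is_positive = [1 if y == positive_class else 0 for y in y_true]
    total_positives = sum(is_positive)
    points = []
    # Use each distinct score as a threshold, strictest first.
    for threshold in sorted(set(scores), reverse=True):
        predicted = [s >= threshold for s in scores]
        tp = sum(1 for p, t in zip(predicted, is_positive) if p and t)
        fp = sum(1 for p, t in zip(predicted, is_positive) if p and not t)
        # tp + fp >= 1 because the threshold is one of the scores.
        precision = tp / (tp + fp)
        recall = tp / total_positives if total_positives else 0.0
        points.append((recall, precision))
    return points


# Toy data: four examples, with the model's score for class 1.
y_true = [0, 0, 1, 1]
class_1_scores = [0.1, 0.4, 0.35, 0.8]
points = precision_recall_points(y_true, class_1_scores, positive_class=1)
for recall, precision in points:
    print(f"recall={recall:.2f} precision={precision:.2f}")
```

Lowering the threshold can only keep recall the same or raise it, while precision may move in either direction, which is why the curves on the chart are not always monotonic.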

Customized usage

To make this chart more legible, I can edit the built-in wandb chart definition (the Vega spec), following the Vega visualization grammar, to produce this new chart, where the differences within a single class are much easier to spot across runs. Now each line's color represents one of my 10 classes, and each line's stroke type (solid, dash, dot) represents one of my three experiments with different numbers of epochs/training examples. You can hover over the top right corner of the chart and click on the "eye" icon to see the full Vega spec.
See the full definition of wandb.plot.pr_curve() →