
Plot Precision Recall Curves

Usage and examples for wandb.plots.precision_recall()
Created on September 30|Last edited on October 13

Method: wandb.plots.precision_recall()

Log a Precision-Recall curve in one line:

wandb.log({"pr" : wandb.plots.precision_recall(ground_truth, predictions, labels=None, classes_to_plot=None)})

You can log this whenever your code has access to:

  • a model's predicted scores (predictions) on a set of examples
  • the corresponding ground truth labels (ground_truth) for those examples
  • (optionally) a list of the labels/class names (labels=["cat", "dog", "bird", ...] if label index 0 means cat, 1 means dog, 2 means bird, etc.)
  • (optionally) a subset of the labels/classes to plot (classes_to_plot)
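As a rough sketch of how these inputs are expected to line up, assume ground_truth holds one integer class index per example and predictions holds one row of per-class scores (for example, softmax probabilities) per example. The hypothetical helper check_pr_inputs below is not part of wandb; it only checks those assumed shapes before you pass the same arguments to wandb.plots.precision_recall(). Treating classes_to_plot as class indices is also an assumption worth confirming against the wandb version you use.

```python
from typing import List, Optional, Sequence


def check_pr_inputs(
    ground_truth: Sequence[int],
    predictions: Sequence[Sequence[float]],
    labels: Optional[List[str]] = None,
    classes_to_plot: Optional[List[int]] = None,
) -> int:
    """Hypothetical sanity check (not a wandb API). Returns the number of classes.

    Assumes one integer class index per example in ground_truth and one row of
    per-class scores per example in predictions.
    """
    if not predictions:
        raise ValueError("predictions must not be empty")
    if len(ground_truth) != len(predictions):
        raise ValueError("need one prediction row per ground-truth label")
    n_classes = len(predictions[0])
    # Every example must carry a score for every class.
    if any(len(row) != n_classes for row in predictions):
        raise ValueError("every prediction row needs one score per class")
    # Ground-truth labels are assumed to be indices into the score rows.
    if any(not 0 <= y < n_classes for y in ground_truth):
        raise ValueError("ground-truth indices must be in [0, n_classes)")
    # labels[i] names class index i, so lengths must match.
    if labels is not None and len(labels) != n_classes:
        raise ValueError("labels must name every class, in index order")
    # classes_to_plot is assumed here to be a subset of class indices.
    if classes_to_plot is not None and any(
        not 0 <= c < n_classes for c in classes_to_plot
    ):
        raise ValueError("classes_to_plot must be valid class indices")
    return n_classes
```

For a three-class toy problem, check_pr_inputs([0, 1, 2, 1], predictions, labels=["cat", "dog", "bird"]) returns 3 when predictions has four rows of three scores each, and raises ValueError if, say, only two labels are given.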


Basic usage

I fine-tune a CNN to predict 10 classes of living things: plants, birds, insects, etc. In my validation step, I call

wandb.log({"my_custom_plot_id" : wandb.plots.precision_recall(ground_truth, 
                         predictions, labels=["Amphibia", "Animalia", ... "Reptilia"])})

to produce the following curve for each run of my model (where each run logs to the same plot key, my_custom_plot_id):




[Chart: Toy CNN runs, 24 runs]
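To make concrete what each per-class curve shows: conceptually it is a one-vs-rest precision-recall curve, where one class is treated as positive, a threshold is swept over that class's predicted score, and precision and recall are recorded at each threshold. The plain-Python sketch below computes those points under the same assumed input shapes (integer ground-truth indices, one row of per-class scores per example). The function name one_vs_rest_pr_points is hypothetical, and this is an illustration of the idea, not wandb's own implementation, which may handle thresholds and curve endpoints differently.

```python
from typing import List, Sequence, Tuple


def one_vs_rest_pr_points(
    ground_truth: Sequence[int],
    predictions: Sequence[Sequence[float]],
    class_index: int,
) -> List[Tuple[float, float]]:
    """Hypothetical helper: (recall, precision) pairs for one class, one per threshold."""
    # One-vs-rest: this class is "positive", every other class is "negative".
    scores = [row[class_index] for row in predictions]
    is_positive = [y == class_index for y in ground_truth]
    total_positives = sum(is_positive)
    if total_positives == 0:
        raise ValueError("class has no positive examples; recall is undefined")

    points = []
    # Sweep thresholds from the highest distinct score down; an example counts
    # as predicted positive when its score is >= the threshold.
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for s, p in zip(scores, is_positive) if s >= threshold and p)
        fp = sum(1 for s, p in zip(scores, is_positive) if s >= threshold and not p)
        # tp + fp >= 1 because the threshold equals at least one score.
        points.append((tp / total_positives, tp / (tp + fp)))
    return points
```

With ground_truth = [0, 1, 2, 1] and class-1 scores of 0.2, 0.8, 0.2, 0.4, the thresholds 0.8, 0.4 and 0.2 give the points (0.5, 1.0), (1.0, 1.0) and (1.0, 0.5): recall rises as the threshold drops, and precision falls once negatives start to be included.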


Customized usage

To make this chart more legible, I can edit the built-in wandb chart definition (its Vega spec), following the Vega visualization grammar, to produce the new chart below, where differences for a single class are much easier to spot across runs:




[Chart: Toy CNN runs, 24 runs]
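One possible edit of this kind, sketched here as a Python dict in Vega-Lite syntax (the higher-level grammar built on Vega), keeps only the rows for one class and draws one colored line per run. This is an illustration, not the spec behind the chart above: the field names class, recall, precision, and run are hypothetical and must be renamed to match the columns your chart's query actually returns, and whether the chart editor expects full Vega or Vega-Lite is worth checking before pasting it in.

```python
import json


def single_class_pr_spec(class_name: str) -> dict:
    """Hypothetical Vega-Lite spec: one class's PR curve, one line per run.

    The field names "class", "recall", "precision", and "run" are assumptions
    about the chart's data table, not names taken from wandb.
    """
    return {
        "$schema": "https://vega.github.io/schema/vega-lite/v4.json",
        "mark": "line",
        "transform": [
            # Keep only the rows belonging to the class we want to compare.
            {"filter": {"field": "class", "equal": class_name}},
        ],
        "encoding": {
            "x": {"field": "recall", "type": "quantitative"},
            "y": {"field": "precision", "type": "quantitative"},
            # One colored line per run, so all runs share the same class.
            "color": {"field": "run", "type": "nominal"},
        },
    }


print(json.dumps(single_class_pr_spec("Amphibia"), indent=2))
```

Filtering to a single class trades the overview of all classes for a direct run-to-run comparison; rendering one such chart per class of interest recovers the rest.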