Plot Precision Recall Curves
Usage and examples for wandb.plots.precision_recall()
Created on September 30 | Last edited on October 13
Method: wandb.plots.precision_recall()
Log a Precision-Recall curve in one line:
wandb.log({"pr" : wandb.plots.precision_recall(ground_truth, predictions, labels=None, classes_to_plot=None)})
You can log this whenever your code has access to:
- a model's predicted scores (predictions) on a set of examples
- the corresponding ground truth labels (ground_truth) for those examples
- (optionally) a list of the labels/class names (labels=["cat", "dog", "bird"...] if label index 0 means cat, 1 = dog, 2 = bird, etc.)
- (optionally) a subset list of the labels/classes to plot (classes_to_plot)
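As a rough illustration of how those inputs might be prepared, the sketch below assumes a multi-class classifier that emits raw logits, and assumes that predictions should be one row of per-class scores per example (softmax probabilities here). The softmax helper and the toy values are hypothetical and are not part of the wandb API.

```python
import math

def softmax(logits):
    """Convert one example's raw logits into per-class probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical validation outputs: 3 examples, 3 classes.
labels = ["cat", "dog", "bird"]   # index 0 = cat, 1 = dog, 2 = bird
ground_truth = [0, 2, 1]          # true class index for each example
raw_logits = [[2.0, 0.5, -1.0],
              [0.1, 0.2, 3.0],
              [0.3, 1.5, 0.0]]
# Assumed layout: one row of per-class scores per example.
predictions = [softmax(row) for row in raw_logits]

# Sanity checks: one score row per example, one score per class, valid indices.
assert len(predictions) == len(ground_truth)
assert all(len(row) == len(labels) for row in predictions)
assert all(0 <= y < len(labels) for y in ground_truth)
```

The resulting ground_truth, predictions, and labels correspond to the three arguments in the one-line wandb.plots.precision_recall call at the top.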
Basic usage
I fine-tune a CNN to predict 10 classes of living things (plants, birds, insects, etc.). In my validation step, I call
wandb.log({"my_custom_plot_id": wandb.plots.precision_recall(ground_truth, predictions, labels=["Amphibia", "Animalia", ... "Reptilia"])})
to produce the following curve for each run of my model (each run logs to the same plot key, my_custom_plot_id):
[Chart: Toy CNN runs]
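Conceptually, each class's line in a chart like this is a one-vs-rest precision-recall curve: that class is treated as positive, every other class as negative, and the threshold on that class's score is swept from high to low. The sketch below shows this standard technique in plain Python. The function name pr_curve_one_vs_rest and the toy data are hypothetical, and W&B's own implementation may handle thresholds, ties, or interpolation differently.

```python
def pr_curve_one_vs_rest(ground_truth, predictions, class_index):
    """Return (recall, precision) points for one class vs. all others.

    ground_truth: list of true class indices, one per example.
    predictions: list of per-class score rows, one row per example.
    """
    # Pair each example's score for this class with whether it truly is this class.
    scored = [(row[class_index], label == class_index)
              for row, label in zip(predictions, ground_truth)]
    total_pos = sum(1 for _, is_pos in scored if is_pos)
    if total_pos == 0:
        raise ValueError("class has no positive examples; recall is undefined")

    # Lowering the threshold admits examples in descending score order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    points = []
    tp = fp = 0
    for i, (score, is_pos) in enumerate(scored):
        if is_pos:
            tp += 1
        else:
            fp += 1
        # Tied scores share one threshold, so emit a point only after the last tie.
        if i + 1 < len(scored) and scored[i + 1][0] == score:
            continue
        points.append((tp / total_pos, tp / (tp + fp)))
    return points


# Toy data: 4 examples, 3 classes (each score row sums to 1).
ground_truth = [0, 1, 2, 0]
predictions = [[0.7, 0.2, 0.1],
               [0.3, 0.6, 0.1],
               [0.2, 0.2, 0.6],
               [0.4, 0.5, 0.1]]
points = pr_curve_one_vs_rest(ground_truth, predictions, class_index=0)
print(points)
```

Calling the function once per class index gives one curve per class, which is the per-class view the chart draws for each run.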
Customized usage
To make this chart more legible, I can edit the built-in W&B chart definition (its Vega spec), following the Vega visualization grammar, to produce this new chart, in which differences for a single class are much easier to spot across runs:
[Chart: Toy CNN runs, customized Vega spec]
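One way to get that kind of per-class comparison is to filter the data to a single class and draw one colored line per run. The sketch below builds such a spec in Vega-Lite, a higher-level grammar that compiles to Vega. The helper single_class_pr_spec, the data source name "table", and the field names run, class, recall, and precision are all hypothetical; the built-in W&B chart is defined in Vega and its fields may differ, so this shows the approach rather than a spec to paste into the editor as-is.

```python
import json

def single_class_pr_spec(class_name):
    """Build a Vega-Lite spec drawing one precision-recall line per run for one class.

    Field names ("run", "class", "recall", "precision") are hypothetical and must
    match the columns of whatever table the chart is bound to.
    """
    return {
        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
        "data": {"name": "table"},
        # Keep only rows for the class of interest, so runs are compared on one class.
        "transform": [{"filter": {"field": "class", "equal": class_name}}],
        "mark": "line",
        "encoding": {
            "x": {"field": "recall", "type": "quantitative"},
            "y": {"field": "precision", "type": "quantitative"},
            # One colored line per run.
            "color": {"field": "run", "type": "nominal"},
        },
    }

spec = single_class_pr_spec("Reptilia")
print(json.dumps(spec, indent=2))
```

Passing a different class name (for example "Amphibia") would redraw the same run-by-run comparison for another class.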