
Log ROC Curves, Precision-Recall Curves, and Confusion Matrices With W&B

In this article, we explore how to log precision-recall curves, ROC curves, and confusion matrices natively using Weights & Biases.
This article explains how to log precision-recall (PR) curves, receiver operating characteristic (ROC) curves, and confusion matrices natively using Weights & Biases. You can also use our heat maps to create attention maps.


Table of Contents

  • ROC and PR Curves
  • Confusion Matrices
  • Heat Maps

ROC and PR Curves

ROC Curve

ROC curves plot the true positive rate (TPR, y-axis) against the false positive rate (FPR, x-axis). The ideal score is TPR = 1 and FPR = 0, the point at the top left. Typically we calculate the area under the ROC curve (AUC-ROC), and the greater the AUC-ROC, the better.
Here we can see our model is slightly better at predicting the class Negative emotion, as evidenced by the larger area under its ROC curve.


Example

# ROC curve
wandb.log({"roc": wandb.plot.roc_curve(ground_truth, predictions,
                                       labels=None, classes_to_plot=None)})

You can log this whenever your code has access to:

  • a model's predicted scores (predictions) on a set of examples
  • the corresponding ground truth labels (ground_truth) for those examples
  • (optionally) a list of the labels/class names (labels=["cat", "dog", "bird"...] if label index 0 means cat, 1 = dog, 2 = bird, etc.)
  • (optionally) a subset (still in list format) of these labels to visualize on the plot
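For context, here's a minimal end-to-end sketch. The dataset, model, and project name are illustrative stand-ins rather than the article's originals; any scikit-learn classifier with a predict_proba method works the same way.

import wandb
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Illustrative data and model
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Per-class predicted scores, shape (n_samples, n_classes)
predictions = model.predict_proba(X_test)

wandb.init(project="plots-demo")  # hypothetical project name
wandb.log({"roc": wandb.plot.roc_curve(y_test, predictions,
                                       labels=["setosa", "versicolor", "virginica"])})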


Precision-Recall Curves

A precision-recall curve computes the tradeoff between precision and recall at different thresholds. A high area under the curve represents both high recall and high precision: high precision relates to a low false positive rate, and high recall relates to a low false negative rate.
High scores for both show that the classifier is returning accurate results (high precision) as well as returning a majority of all positive results (high recall). PR curves are especially useful when the classes are very imbalanced.

Example

# Precision-recall curve

wandb.log({"pr": wandb.plot.pr_curve(ground_truth, predictions,
                                     labels=None, classes_to_plot=None)})
You can log this whenever your code has access to:
  • a model's predicted scores (predictions) on a set of examples
  • the corresponding ground truth labels (ground_truth) for those examples
  • (optionally) a list of the labels/class names (labels=["cat", "dog", "bird"...] if label index 0 means cat, 1 = dog, 2 = bird, etc.)
  • (optionally) a subset (still in list format) of the labels to visualize in the plot
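As a concrete illustration, reusing the fitted model, test split, and active run from the ROC sketch above (all of those names are stand-ins):

# Per-class predicted scores for the same test set
predictions = model.predict_proba(X_test)

wandb.log({"pr": wandb.plot.pr_curve(y_test, predictions,
                                     labels=["setosa", "versicolor", "virginica"],
                                     classes_to_plot=[0, 2])})  # restrict the plot to classes 0 and 2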




Confusion Matrices

A confusion matrix evaluates the accuracy of a classification. It's useful for assessing the quality of model predictions and finding patterns in the predictions the model gets wrong.
The diagonal represents the predictions the model got right, i.e. where the actual label is equal to the predicted label.

Example

# Confusion matrix

# `nb` is a classifier fitted earlier in the tutorial; `y_pred` holds its predicted labels
wandb.sklearn.plot_confusion_matrix(y_test, y_pred, nb.classes_)
  • y_true (arr): Test set labels.
  • y_pred (arr): Test set predicted labels.
  • labels (list): Named labels for the target variable (y).
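A minimal runnable sketch, continuing with the train/test split from the ROC sketch above; the GaussianNB model here is an illustrative stand-in for the article's fitted nb classifier:

from sklearn.naive_bayes import GaussianNB

# Any fitted scikit-learn classifier works; GaussianNB is just an example
nb = GaussianNB().fit(X_train, y_train)
y_pred = nb.predict(X_test)  # predicted class labels, not probabilities

wandb.sklearn.plot_confusion_matrix(y_test, y_pred, nb.classes_)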

Heat Maps

Heat maps can be used to make attention maps, confusion matrices, and more.
# Heat map

'''
Arguments:
matrix_values (arr): 2D dataset of shape x_labels * y_labels, containing
heatmap values that can be coerced into an ndarray.
x_labels (list): Named labels for rows (x_axis).
y_labels (list): Named labels for columns (y_axis).
show_text (bool): Show text values in heatmap cells.
'''
wandb.log({'heatmap_with_text': wandb.plots.HeatMap(x_labels, y_labels, matrix_values, show_text=True)})
Here's an example of the attention maps for a neural machine translation model that translates English → French. We show the attention maps at the 2nd, 20th, and 100th epochs. We can see that the model starts out not knowing which words to pay attention to (it uses <res> to predict all words), and slowly learns which ones to attend to over the course of the next 100 epochs.
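If you want to try the call above, here's a minimal sketch with made-up tokens and random weights standing in for a real attention matrix; the project name and all labels are hypothetical:

import numpy as np
import wandb

x_labels = ["the", "cat", "sat"]    # source-sentence tokens
y_labels = ["le", "chat", "assis"]  # target-sentence tokens
# Random stand-in for attention weights, shape x_labels * y_labels per the docstring above
matrix_values = np.random.rand(len(x_labels), len(y_labels))

wandb.init(project="plots-demo")  # hypothetical project name
wandb.log({"attention_map": wandb.plots.HeatMap(x_labels, y_labels,
                                                matrix_values, show_text=True)})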

