Log ROC Curves, Precision-Recall Curves, and Confusion Matrices With W&B
In this article, we explore how to log precision-recall curves, ROC curves, and confusion matrices natively using Weights & Biases.
Created on April 9 | Last edited on October 9
This article explains how to log precision-recall (PR) and receiver operating characteristic (ROC) curves, and confusion matrices natively using Weights & Biases. You can also use our heat maps to create attention maps.
ROC and PR Curves
ROC Curve
ROC curves plot the true positive rate (y-axis) against the false positive rate (x-axis). The ideal point is TPR = 1 and FPR = 0, the top-left corner of the plot. Typically we calculate the area under the ROC curve (AUC-ROC), and the greater the AUC-ROC, the better.
Here we can see our model is slightly better at predicting the class Negative emotion, as evidenced by the larger area under the ROC.
Example
# ROC
wandb.log({"roc": wandb.plot.roc_curve(ground_truth, predictions,
                                       labels=None, classes_to_plot=None)})
You can log this whenever your code has access to:
- a model's predicted scores (predictions) on a set of examples
- the corresponding ground truth labels (ground_truth) for those examples
- (optionally) a list of the labels/class names (labels=["cat", "dog", "bird"...] if label index 0 means cat, 1 = dog, 2 = bird, etc.)
- (optionally) a subset (still in list format) of these labels to visualize on the plot
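To make these inputs concrete, here is a minimal NumPy sketch (the scores and labels are invented, not from the article) that computes the same TPR/FPR points and AUC-ROC that the logged plot is built from:

```python
import numpy as np

# Hypothetical binary classifier outputs, matching the inputs
# wandb.plot.roc_curve expects: 0/1 ground-truth labels and the
# model's positive-class scores for the same examples.
ground_truth = np.array([0, 0, 1, 1, 0, 1, 1, 0])
predictions  = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9, 0.7, 0.3])

def roc_points(y_true, y_score):
    """Sweep each score as a threshold and collect (FPR, TPR) pairs."""
    thresholds = np.sort(np.unique(y_score))[::-1]  # highest first
    pos, neg = (y_true == 1).sum(), (y_true == 0).sum()
    fpr, tpr = [0.0], [0.0]
    for t in thresholds:
        pred_pos = y_score >= t
        tpr.append((pred_pos & (y_true == 1)).sum() / pos)
        fpr.append((pred_pos & (y_true == 0)).sum() / neg)
    return np.array(fpr), np.array(tpr)

fpr, tpr = roc_points(ground_truth, predictions)
# Area under the curve via the trapezoidal rule.
auc = np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2)
```

A curve hugging the top-left corner (low FPR at high TPR) drives `auc` toward 1.0, which is what "greater AUC-ROC is better" means in practice.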
Precision-Recall Curves
A precision-recall curve computes the tradeoff between precision and recall at different thresholds. A high area under the curve represents both high recall and high precision, where high precision relates to a low false positive rate, and high recall relates to a low false negative rate.
High scores for both show that the classifier is returning accurate results (high precision), as well as returning a majority of all positive results (high recall). PR curves are useful when the classes are very imbalanced.
Example
# Precision-Recall
wandb.log({"pr": wandb.plot.pr_curve(ground_truth, predictions,
                                     labels=None, classes_to_plot=None)})
You can log this whenever your code has access to:
- a model's predicted scores (predictions) on a set of examples
- the corresponding ground truth labels (ground_truth) for those examples
- (optionally) a list of the labels/class names (labels=["cat", "dog", "bird"...] if label index 0 means cat, 1 = dog, 2 = bird, etc.)
- (optionally) a subset (still in list format) of the labels to visualize in the plot
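The precision/recall tradeoff the plot shows can be sketched directly (scores and labels are invented, not from the article):

```python
import numpy as np

# Same hypothetical inputs as wandb.plot.pr_curve expects.
ground_truth = np.array([0, 0, 1, 1, 0, 1, 1, 0])
predictions  = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9, 0.7, 0.3])

def pr_points(y_true, y_score):
    """Precision and recall at each score threshold, highest threshold first."""
    thresholds = np.sort(np.unique(y_score))[::-1]
    pos = (y_true == 1).sum()
    precision, recall = [], []
    for t in thresholds:
        pred_pos = y_score >= t
        tp = (pred_pos & (y_true == 1)).sum()
        precision.append(tp / pred_pos.sum())  # of predicted positives, how many are right
        recall.append(tp / pos)                # of actual positives, how many we found
    return precision, recall

precision, recall = pr_points(ground_truth, predictions)
```

Note that lowering the threshold can only raise recall, while precision typically falls, which is why the curve slopes down as it moves right.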
Confusion Matrices
A confusion matrix evaluates the accuracy of a classification. It's useful for assessing the quality of model predictions and finding patterns in the predictions the model gets wrong.
The diagonal represents the predictions the model got right, i.e. where the actual label is equal to the predicted label.
Example
# Confusion Matrix
wandb.sklearn.plot_confusion_matrix(y_test, y_pred, nb.classes_)
- y_true (arr): Test set labels.
- y_pred (arr): Test set predicted labels.
- labels (list): Named labels for target variable (y).
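Here is a minimal sketch (labels invented, not from the article) of how such a matrix is built from `y_true`/`y_pred`, and why its diagonal holds the correct predictions:

```python
import numpy as np

# Hypothetical 3-class setup; class index order matches `labels`.
labels = ["cat", "dog", "bird"]
y_true = np.array([0, 1, 2, 2, 0, 1, 1, 2])
y_pred = np.array([0, 1, 2, 0, 0, 2, 1, 2])

n = len(labels)
cm = np.zeros((n, n), dtype=int)
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1  # rows: actual class, columns: predicted class

# The diagonal counts predictions where actual == predicted,
# so overall accuracy is the diagonal sum over the total.
accuracy = np.trace(cm) / cm.sum()
```

Off-diagonal cells reveal the error patterns the section mentions, e.g. `cm[2, 0]` counts how often a "bird" was mistaken for a "cat".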
Heat Maps
Heat maps can be used to make attention maps, confusion matrices, and more.
# Heat Map
'''
Arguments:
    matrix_values (arr): 2D dataset of shape x_labels * y_labels, containing
        heatmap values that can be coerced into an ndarray.
    x_labels (list): Named labels for rows (x_axis).
    y_labels (list): Named labels for columns (y_axis).
    show_text (bool): Show text values in heatmap cells.
'''
wandb.log({'heatmap_with_text': wandb.plots.HeatMap(x_labels, y_labels,
                                                    matrix_values, show_text=False)})
Here's an example of the attention maps for a neural machine translation model that translates English → French. We show the attention maps at the 2nd, 20th, and 100th epochs. The model starts out not knowing which words to pay attention to (it uses <res> to predict all words), and slowly learns which ones to attend to over the course of 100 epochs.
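As a sketch of how such an attention map could be assembled for logging (the tokens and weights below are invented, not from the article's model), each row of `matrix_values` is a softmax distribution over target tokens:

```python
import numpy as np

# Hypothetical source/target tokens for an English -> French pair.
x_labels = ["the", "cat", "sat"]      # source tokens (rows)
y_labels = ["le", "chat", "assis"]    # target tokens (columns)

# Toy attention logits; softmax each row so it sums to 1,
# as attention weights from a trained model would.
raw = np.array([[5.0, 1.0, 0.5],
                [0.5, 6.0, 1.0],
                [0.2, 0.8, 4.0]])
matrix_values = np.exp(raw) / np.exp(raw).sum(axis=1, keepdims=True)

# Then log it, e.g.:
# wandb.log({"attention": wandb.plots.HeatMap(x_labels, y_labels,
#                                             matrix_values, show_text=True)})
```

Early in training the rows would be nearly uniform (the "doesn't know where to look" stage); as training progresses, mass concentrates on the aligned token.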