
Torcheval: The New Metrics Library for PyTorch

Digging into the new metrics library for our PyTorch training loops.
Created on December 29 | Last edited on January 6
The PyTorch torcheval library solves a lot of issues with metric computation in our training scripts. In this report, we'll look at what it solves, how it works, and how you can integrate it into your workflow.
You can follow along in this free Colab as well!


Introduction

How we compute metrics for our training scripts in PyTorch has always been a choice. That's because no native metrics computation is provided by the torch library.
To solve this, people have created their own custom metrics packages or found ways of using the well-tested scikit-learn metrics inside PyTorch. That said, there are multiple caveats to writing custom metrics or leveraging scikit-learn. For instance, tensor (torch.Tensor) objects live on a device like a GPU most of the time, as we want fast computation for our deep neural networks. So to compute metrics like accuracy, precision, or recall, you need to make sure that the predictions and targets live on the same device.
😱 Scikit-learn requires numpy arrays as inputs, so you will need to bring those tensors back to the CPU, convert them to numpy arrays, and then pass them to the corresponding scikit-learn metric function.
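To make that concrete, here is a minimal sketch of the round trip, assuming a GPU is available and the predictions and targets already live on it:
import torch
from sklearn.metrics import accuracy_score

# hypothetical batch of predictions/targets that lives on the GPU
preds = (torch.rand(8, device="cuda") > 0.5).int()
truth = torch.randint(0, 2, (8,), device="cuda")

# scikit-learn needs numpy, so: GPU -> CPU -> numpy -> metric
acc = accuracy_score(truth.cpu().numpy(), preds.cpu().numpy())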
Also, since we're performing mini-batch optimization most of the time, we can only compute the final dataset-level metrics once the entire dataset has passed through the model. So we need to accumulate the metric values until the end, when we can compute the final metric.
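Without a metrics library, that accumulation typically looks something like this sketch (model and dataloader are placeholders, and we assume per-example probabilities thresholded at 0.5):
correct, total = 0, 0
for images, labels in dataloader:
    preds = model(images)
    # accumulate batch-level counts...
    correct += ((preds > 0.5).long() == labels).sum().item()
    total += labels.numel()

# ...and only compute the dataset-level metric at the end
dataset_accuracy = correct / total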
Another critical issue arises when training large models in a distributed environment. In this scenario, to compute a metric across a distributed DataLoader, you have to gather the tensors from every process and wait for synchronization, so you can differentiate between the metric on a particular rank and the global metric.
TorchEval helps with all of this. It also includes evaluation tools beyond metrics, like FLOP counting and module summarization. Let's dig into what we can do below!
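As a quick taste of the distributed case: torcheval ships a toolkit helper that gathers the metric state from every rank and computes the global value. A minimal sketch, assuming torch.distributed has already been initialized (e.g. via torchrun) and local_preds / local_labels stand for the current rank's shard of the data:
from torcheval.metrics import MulticlassAccuracy
from torcheval.metrics.toolkit import sync_and_compute

metric = MulticlassAccuracy()

# each rank updates the metric with its own data only
metric.update(local_preds, local_labels)

# gather every rank's state and compute the dataset-level value
global_acc = sync_and_compute(metric)
And for module inspection, a sketch of the summary tool, assuming model is an nn.Module you've already built:
from torcheval.tools import get_module_summary

print(get_module_summary(model))  # per-module parameter counts and sizes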


How does it work?

This is just another Python library. Basically, you pip install torcheval and you're good to go!
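In a fresh environment, that's simply:
pip install torcheval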
It's worth noting that we have two ways of using the library: a functional API and a stateful API. Depending on your use case, you may want to use one over the other.

The functional API

You call the metrics on your tensors, and that's it. If you want to compute your metrics on a set of model predictions, this could be all you need. There are many available metrics, and more will surely be added in the future. Take a look at what metrics are available here.
import torch
from torcheval.metrics.functional import binary_accuracy

preds = torch.tensor([0.1,0.7,0.6])
truth = torch.tensor([0,1,0])

binary_accuracy(preds, truth)

> tensor(0.6667)
This is not enough to instrument our training loops, as we want to accumulate the values over the dataset before computing the final metric.

The stateful API

When computing metrics over a dataset in a mini-batch training loop, you want to accumulate values while the batches pass through the model and delay the computation of the metric until the end of the epoch, once the entire dataset has passed through it.
The stateful API enables you to do this by updating the metric and computing the final value:
from torcheval.metrics import BinaryAccuracy

accuracy = BinaryAccuracy()

preds1 = torch.tensor([0.1,0.7,0.6])
truth1 = torch.tensor([0,1,0])

# stores the values and accumulates into the metric instance
accuracy.update(preds1, truth1)

# computes the current value with the available data
accuracy.compute()

> tensor(0.6667)
But we can pass another batch of data to the metric:
preds2 = torch.tensor([0.4,0.9,0.1])
truth2 = torch.tensor([1,1,1])

accuracy.update(preds2, truth2)

accuracy.compute()

> tensor(0.5000)
This is equivalent to calling the functional metric over the full two-batch dataset!
binary_accuracy(input=torch.cat([preds1, preds2]),
                target=torch.cat([truth1, truth2]))

> tensor(0.5000)

Integrating torcheval with your training scripts


To use torcheval metrics in your training scripts, define the metrics you want to track before the training process begins. You should create a separate metric instance for each dataset (training and validation) so that the values are tracked separately.
🐝 A cool trick: you can also keep track of mean/median values, so we can use the Mean metric to track the average loss:
from torcheval.metrics import MulticlassAccuracy, Mean

train_acc = MulticlassAccuracy(device=config.device)
valid_acc = MulticlassAccuracy(device=config.device)

# another cool trick is keep track of the loss as a metric!
train_loss = Mean(device=config.device)
valid_loss = Mean(device=config.device)
Then, in your training loop, accumulate the values of the metrics by calling update:
preds_b = model(images)
loss = loss_func(preds_b, labels)
train_step(loss)
# update metrics
train_loss.update(loss.detach(), weight=len(images))  # 🐝
train_acc.update(preds_b, labels) # 🐝
Whenever you want to show or log the metric values, call compute on the metric object, for example when printing them out or saving them to your Weights & Biases project.
# log to W&B
wandb.log({"train_loss": train_loss.compute(),
"train_acc": train_acc.compute()})
Finally, you should reset the metrics at the end of each epoch so they don't keep accumulating and instead restart from zero at the next epoch. Remember that one epoch represents one complete pass over the dataset.
You can wrap the metrics reset procedure into a function like so:
def reset_metrics():
    train_acc.reset()
    valid_acc.reset()
    train_loss.reset()
    valid_loss.reset()
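To make the ordering concrete, here's a sketch of where the pieces sit in an epoch loop (train_one_epoch, validate, and config.epochs are placeholders standing in for the snippets above):
for epoch in range(config.epochs):
    train_one_epoch()   # calls .update() on train_loss / train_acc for every batch
    validate()          # calls .update() on valid_loss / valid_acc for every batch

    # read out the epoch-level values...
    print(f"epoch {epoch}: train_acc={train_acc.compute():.3f}, valid_acc={valid_acc.compute():.3f}")

    # ...then start the next epoch from a clean slate
    reset_metrics()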

Integration with Weights and Biases (W&B) 🐝🪄




We are all about metrics! Saving all the relevant information about your ML projects and training logs is the core mission of W&B. Torcheval streamlines the computation of your metrics and makes your code more readable.
From a training script perspective, there are two places where you want to log metrics to W&B:
  • During training, more precisely when iterating over the training DataLoader after computing the model's forward pass. At this point, you will be interested in logging the training loss, the relevant metrics, and the hyperparameters that change over time, like the scheduler's learning rate, weight decay, or other fancy parameters you may be trying out!
# update metrics
train_loss.update(loss.detach(), weight=len(images))
train_acc.update(preds_b, labels)

# log to W&B
wandb.log({"train_loss": train_loss.compute(),
"train_acc": train_acc.compute(),
"learning_rate": scheduler.get_last_lr()[0]})

print(f"train_loss={train_loss.compute():2.3f}, train_acc={train_acc.compute():2.3f}")
  • At the end of the validation step, you want to keep track of your model's performance on unseen data:
for images, labels in valid_dataloader:
    with torch.inference_mode():
        preds_b = model(images)

    # update metrics
    valid_loss.update(loss_func(preds_b, labels).detach(), weight=len(images))
    valid_acc.update(preds_b, labels)

# log to W&B once at the end of the validation loop
wandb.log({"valid_loss": valid_loss.compute(),
           "valid_acc": valid_acc.compute()})
Putting everything together, you can look at your model's performance and the metrics you computed during training and evaluation:

[W&B panel: Run set (1 run)]



Other libraries

There are multiple metrics solutions out there, and torcheval has taken inspiration and ideas from all of them to build an integrated PyTorch solution.

torchmetrics

This library is probably the most widely used, as it has excellent testing and supports most of the features that torcheval aims to provide. It also has a much larger collection of metrics available right now. Hence, it is odd to see a replacement library come from the PyTorch team instead of joining forces and powering up torchmetrics, which already has extensive testing and adoption. I suppose the PyTorch team wanted tight control over the integration, and will probably merge this (for the moment) separate package into PyTorch itself.
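For reference, the stateful API in a recent torchmetrics version looks very similar; a minimal sketch:
import torch
from torchmetrics import Accuracy

accuracy = Accuracy(task="binary")

accuracy.update(torch.tensor([0.1, 0.7, 0.6]), torch.tensor([0, 1, 0]))
accuracy.compute()   # same update/compute pattern as torcheval's stateful API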
There is an ongoing discussion on a GitHub issue addressing this.
