Working with FuncTorch: An Introduction
Working with JAX-like composable function transforms in PyTorch
Introduction:
The recent 1.11 release of PyTorch added a few important new functionalities to our machine learning toolbox, including the beta releases of FuncTorch and TorchData.
NOTE: FuncTorch is currently in beta, which means that the features generally work (unless otherwise documented) and that the PyTorch team is actively developing the library. However, the APIs may change based on user feedback, and full coverage of PyTorch operations is not yet available. You can refer to the official documentation of FuncTorch here.
For an introduction to TorchData with code, refer to the report "How the TorchData API Works: a Tutorial with Code" linked in the Similar Reports section at the end of this article.
🤔 What is FuncTorch?
JAX-like composable function transforms for PyTorch
Essentially, FuncTorch is a library that provides composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.
🧐 Why Use Composable Function Transforms?
There are a number of use cases that are tricky to implement in vanilla PyTorch. A few examples:
- Computing per-sample-gradients (or other per-sample quantities)
- Running ensembles of models on a single machine
- Efficiently computing Jacobians and Hessians in a batched manner
- Efficiently implementing and utilizing stateless optimizers in a functional manner
NOTE: Composing vmap(), grad(), and vjp() transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.
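To make the idea concrete, here is a minimal sketch (not from this report's codebase) of how grad and vmap compose to produce per-sample gradients. The toy linear model, loss, and tensor shapes are illustrative assumptions; make_functional, grad, and vmap are the actual functorch transforms.

import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad, vmap

# toy model, purely for illustration
func_model, params = make_functional(nn.Linear(3, 2))

def loss_fn(params, x, y):
    # loss for a single example, treated as a batch of one
    logits = func_model(params, x.unsqueeze(0))
    return F.cross_entropy(logits, y.unsqueeze(0))

# grad(loss_fn) returns per-parameter gradients for one example;
# vmap maps it over the batch dimension of x and y while sharing params
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))

batch_x = torch.randn(8, 3)
batch_y = torch.randint(0, 2, (8,))
grads = per_sample_grads(params, batch_x, batch_y)  # one gradient per example, per parameter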
🎤 FuncTorch in Action
Now, let's demonstrate the usage of composable function transforms and the overall stateless approach to computation that FuncTorch enables. To do this, we'll build a very simple image classification model using the CIFAR-10 dataset.
✅ All the code used in this report can be found at https://github.com/soumik12345/functorch-examples
We're collapsing some of the code and setup below (namely the sections on model creation and training with vanilla PyTorch), but you can expand them for a more thorough accounting of those steps.
🛋 Initial Setup
🏁 Initializing Weights and Biases
It likely comes as no surprise that we'll be using Weights & Biases to track our experiments in a rich, interactive dashboard, as well as to store and version our dataset using Artifacts. Let's start here:
import wandb
import torch

# Initialize a wandb run
wandb.init(project="functorch-examples")

# Set up experiment configs
config = wandb.config
config.batch_size = 64
config.num_workers = 2
config.learning_rate = 1e-2
config.epochs = 5
config.device = "cuda:0" if torch.cuda.is_available() else "cpu"
config.classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
)
💿 Building PyTorch DataLoaders
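The DataLoader code is collapsed in the original report. Below is a rough sketch of what it might look like, assuming the standard torchvision CIFAR-10 dataset, simple normalization, and a small to_device helper. The names train_loader, test_loader, and to_device are used later in the training loop, but the exact transforms here are an assumption.

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
test_loader = DataLoader(test_set, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

def to_device(data, device):
    # move an [inputs, labels] pair to the target device
    return [x.to(device) for x in data]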
🤖 Creating the Model
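The model code is also collapsed, so here is a minimal sketch of a Classifier that would work with the rest of the pipeline. The actual architecture used in the report may differ; the layer sizes below are an assumption, chosen only so the shapes line up for 32x32 CIFAR-10 images. (The Accuracy metric used later presumably comes from torchmetrics.)

import torch
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv2(x)))  # 16x16 -> 8x8
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)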
🍨 Training Using Vanilla PyTorch
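For reference (the full section is collapsed above), here is a condensed sketch of the vanilla PyTorch training step that the functional version below replaces. It mirrors the commented-out lines in the functional step; the choice of criterion and optimizer here is an assumption, though SGD with cross-entropy matches the functional version.

import torch
import torch.nn as nn

model = Classifier().to(config.device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)

for inputs, labels in train_loader:
    inputs, labels = inputs.to(config.device), labels.to(config.device)
    optimizer.zero_grad()
    outputs = model(inputs)            # stateful forward pass through the module
    loss = criterion(outputs, labels)
    loss.backward()                    # gradients accumulate on model.parameters()
    optimizer.step()                   # in-place, stateful parameter update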
🦾 Training Using FuncTorch
As we saw previously, in vanilla PyTorch the model and its parameters are coupled together into a single entity, which prevents us from using composable function transforms in a stateless manner. To make our model stateless, we can call functorch.make_functional on it. This decouples the parameters from the model and turns it into a pure, stateless function.
from functorch import make_functional

model = Classifier().to(device=config.device)

# params are decoupled from the model; functional_model is now a pure function
functional_model, params = make_functional(model)

# we don't need autograd to store gradients for backprop;
# skipping this step might result in a memory leak
for param in params:
    param.requires_grad_(False)

train_metric = Accuracy().to(config.device)
val_metric = Accuracy().to(config.device)
It is important to call param.requires_grad_(False) (or to detach gradients in the optimizer) so that the parameters don't hold onto the backpropagation graph and keep adding to it instead of releasing that memory.
🎓 Compute Loss and Gradients in a Stateless Manner
In contrast to vanilla PyTorch, we will create a function that computes the loss in a stateless manner, given a batch of inputs and the corresponding labels. It's important that this function accepts the parameters, the input batch, and the label batch, since we will be transforming over them.
def compute_stateless_loss(params, inputs, labels):
    outputs = functional_model(params, inputs)
    loss = F.cross_entropy(outputs, labels)
    return loss
Now that we have a function for computing the loss in a stateless manner, we also need a function that can compute its gradient. We can use the functorch.grad operator to transform the loss function into a function that computes the gradients.
from functorch import grad, grad_and_value

# transformed functions
compute_gradients = grad(compute_stateless_loss)
compute_gradients_and_loss = grad_and_value(compute_stateless_loss)
Note that while functorch.grad transforms a function to return just the gradients, functorch.grad_and_value returns a function that computes both the gradients and the value of the primal (i.e., forward) computation.
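For illustration, calling the two transformed functions defined above on a batch of inputs and labels from the dataloader looks like this:

# gradients w.r.t. params only
gradients = compute_gradients(params, inputs, labels)

# gradients plus the loss value from the same forward pass
gradients, loss = compute_gradients_and_loss(params, inputs, labels)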
👾 Functional Optimizer
One of the advantages of FuncTorch is that, since the parameters are completely decoupled from the model, we can create optimizers in a stateless manner as well. This lets us implement an optimizer as a simple function of the weights and gradients.
Note that PyTorch is missing a "functional optimizer API" (possibly coming soon?) so we're going to naively re-implement Stochastic Gradient Descent as a function.
# naive sgd optimizer implementation
def sgd_optimizer(weights, gradients, learning_rate):
    def step(weight, gradient):
        return weight - learning_rate * gradient
    return [step(weight, gradient) for weight, gradient in zip(weights, gradients)]
There's a feature request on the functorch repository to extend make_functional() or create make_functional_optimizer() to support torch.optim.Optimizer. You can find it at https://github.com/pytorch/functorch/issues/372
🏋️ Training Loop
Now, let's combine everything to put together a training loop:
# functional train step
def functional_step(params, inputs, labels):
    # ------------------------------------------------------------------
    # compute gradients and loss using the transformed function instead of:
    #     outputs = model(inputs)
    #     loss = criterion(outputs, labels)
    #     loss.backward()
    gradients, loss = compute_gradients_and_loss(params, inputs, labels)
    # ------------------------------------------------------------------
    # update the parameters using the functional optimizer
    # instead of `optimizer.step()`
    params = sgd_optimizer(params, gradients, config.learning_rate)
    # ------------------------------------------------------------------
    return params, loss

# training function
def train(functional_step, params, metric):
    running_loss = 0.
    metric.reset()
    # loop over the train dataloader
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader), leave=False):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = to_device(data, config.device)
        # perform a functional train step
        params, loss = functional_step(params, inputs, labels)
        running_loss += loss.item()
        # forward propagation for updating the metric
        outputs = functional_model(params, inputs)
        metric(outputs, labels)
    return params, running_loss, metric

# validation function
def validate(params, metric):
    running_loss = 0.
    metric.reset()
    # loop over the validation dataloader
    for i, data in tqdm(enumerate(test_loader), total=len(test_loader), leave=False):
        inputs, labels = to_device(data, config.device)
        outputs = functional_model(params, inputs)
        loss = F.cross_entropy(outputs, labels)
        running_loss += loss.item()
        metric(outputs, labels)
    return running_loss, metric

# loop over the dataloader multiple times
for epoch in tqdm(range(config.epochs)):
    # perform training (a single loop over the train dataloader)
    params, train_loss, train_metric = train(functional_step, params, train_metric)
    # perform validation (a single loop over the validation dataloader)
    val_loss, val_metric = validate(params, val_metric)
    print(
        f'[{epoch}] train_loss: {train_loss / 2000:.3f}, '
        f'train_accuracy: {train_metric.compute():.3f}, '
        f'val_loss: {val_loss / 2000:.3f}, '
        f'val_accuracy: {val_metric.compute():.3f}'
    )
    wandb.log({
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_accuracy": train_metric.compute(),
        "val_accuracy": val_metric.compute()
    })
🚀 Comparing Vanilla PyTorch vs. FuncTorch
We can see from the following panels that, although FuncTorch is comparable to vanilla PyTorch in terms of results, it comes with a few performance caveats:
- The GPU memory allocation in FuncTorch is significantly higher compared to vanilla PyTorch.
- The overall GPU utilization is roughly similar in both cases, in spite of the higher memory usage for FuncTorch. The memory issue is fixed by calling param.requires_grad_(False) on the model parameters.
Run set (7 runs): W&B panels comparing the vanilla PyTorch and FuncTorch runs.
If you notice significantly higher GPU memory allocation in your functorch code, it might be because the model's params have requires_grad=True; in that case they'll hold onto the backpropagation graph and keep adding to it instead of releasing that memory. Hence, we need to make sure that the params don't require gradients. You can refer to this issue on the functorch repository for more info: https://github.com/pytorch/functorch/issues/639
🎬 Conclusion
In this report, we covered:
- An overview of FuncTorch
- The necessity of composable function transforms
- How to decouple parameters from a model using make_functional
- Stateless loss and gradient computation using the functorch.grad operator
- How to implement a stateless optimizer to update the parameters
- An image classification pipeline in vanilla PyTorch vs. FuncTorch, along with a comparison of their performance
💛 Similar Reports
How the TorchData API Works: a Tutorial with Code
Let's check the new way of building Datasets on latest PyTorch 1.11 with TorchData.
How to save and load models in PyTorch
This article is a machine learning tutorial on how to save and load your models in PyTorch using Weights & Biases for version control.
An Introduction To The PyTorch View Function
Demystify the View function in PyTorch and find a better way to design models.
Matrix Factorization from Scratch in JAX: Regularized SVD for Recommendation Systems
Bayesian Hyperparameter Search with Cross Validation for doubly-regularized Matrix Factorization on MovieLens.