
Working with FuncTorch: An Introduction

Working with JAX-like composable function transforms in PyTorch
Created on March 24|Last edited on April 1


Introduction

The 1.11 release of PyTorch came out recently, and it added a few important new functionalities to our machine learning toolbox:
  • TorchData:  A library that enables building datasets with reusable building blocks called DataPipes
  • FuncTorch: A library that provides us with JAX-like composable function transforms for PyTorch
NOTE: FuncTorch is currently in beta, which means that the features generally work (unless otherwise documented) and the PyTorch team is actively developing the library. However, the APIs may change based on user feedback, and full coverage of PyTorch operations is not yet available. You can refer to the official documentation of FuncTorch here.
💡
For an introduction to TorchData with code, refer to the following report:




🤔 What is FuncTorch?

The official docs describe FuncTorch like so:
JAX-like composable function transforms for PyTorch
Essentially, FuncTorch is a library that provides composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.
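To make this concrete, here is a minimal sketch (not from the original report) of what such a transform looks like: functorch.grad turns a function that returns a scalar into a function that returns the gradient with respect to its first argument.

import torch
from functorch import grad

# grad() transforms torch.sin into a function that computes d/dx sin(x) = cos(x)
x = torch.randn(())
cos_x = grad(torch.sin)(x)
assert torch.allclose(cos_x, torch.cos(x))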

🧐 Why Use Composable Function Transforms?

There are a number of use cases that are tricky to implement in vanilla PyTorch. A few examples:
  • Computing per-sample-gradients (or other per-sample quantities)
  • Running ensembles of models on a single machine
  • Efficiently batching tasks together in the inner-loop of Model-Agnostic Meta-Learning
  • Efficiently computing Jacobians and Hessians in a batched manner
  • Efficiently implementing and utilizing stateless optimizers in a functional manner
NOTE: Composing vmap(), grad(), and vjp() transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.
💡
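As an illustration of the first use case above (per-sample gradients), grad differentiates a scalar-valued loss for a single sample, and vmap maps the resulting function over the batch dimension. The snippet below is a minimal sketch using a toy linear model; the names loss_fn, weights, batched_x, and batched_y are purely illustrative and not part of the report's pipeline.

import torch
from functorch import grad, vmap

def loss_fn(weights, x, y):
    # scalar loss for a single sample under a toy linear model
    prediction = x @ weights
    return ((prediction - y) ** 2).mean()

weights = torch.randn(16)
batched_x = torch.randn(64, 16)  # a batch of 64 samples
batched_y = torch.randn(64)

# grad() gives the gradient for one sample; vmap() maps it over the batch,
# producing one gradient per sample with shape (64, 16)
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(weights, batched_x, batched_y)
print(per_sample_grads.shape)  # torch.Size([64, 16])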




🎤 FuncTorch in Action

Now, let's demonstrate the usage of composable function transforms and the overall stateless approach to computation that FuncTorch enables. To do this, we'll build a very simple image classification model using the CIFAR-10 dataset.
✅ All the code used in this report can be found at https://github.com/soumik12345/functorch-examples
We're collapsing some of the code and setup below (namely the sections on model creation and using vanilla PyTorch), but you can expand those for a more thorough accounting of those steps.

🛋 Initial Setup

🏁 Initializing Weights and Biases

It's likely not a surprise, but we'll be using Weights & Biases to track our experiments on its rich, interactive dashboard, as well as to store and version our dataset using Artifacts. Let's start here:
import wandb
import torch

# Initialize a wandb run
wandb.init(project="functorch-examples")

# Set up experiment configs
config = wandb.config
config.batch_size = 64
config.num_workers = 2
config.learning_rate = 1e-2
config.epochs = 5
config.device = "cuda:0" if torch.cuda.is_available() else "cpu"
config.classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
)


💿 Building PyTorch DataLoaders
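The DataLoader code is collapsed in the original report, but since train_loader, test_loader, and the to_device helper are used later, here is a rough sketch of what a standard CIFAR-10 setup could look like. This is an assumption based on common torchvision usage, not the report's exact collapsed code.

import torch
import torchvision
import torchvision.transforms as T

transform = T.Compose([T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers
)

# assumed helper that moves an [inputs, labels] batch to the target device
def to_device(data, device):
    return [item.to(device) for item in data]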

🤖 Creating the Model
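The model definition is also collapsed above. A minimal Classifier compatible with the rest of the code (a small CNN for 32x32 CIFAR-10 images) might look like the sketch below; the actual architecture used in the report may differ.

import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32 -> 16x16
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 16x16 -> 8x8
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128), nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.head(self.features(x))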

🍨 Training Using Vanilla PyTorch

🦾 Training using FuncTorch

As we have seen previously, in vanilla PyTorch the model and the parameters are coupled together into a single entity. This prevents us from using composable function transforms in a stateless manner. To make our model stateless, we can call functorch.make_functional on it. This decouples the parameters from the model and turns it into a pure, stateless function.
from functorch import make_functional
from torchmetrics import Accuracy  # assumption: the Accuracy metric comes from torchmetrics

# Classifier is defined in the collapsed model-creation section above
model = Classifier().to(device=config.device)
# params are decoupled from the model; functional_model is now a pure function
functional_model, params = make_functional(model)

# we don't need to store grads for backprop;
# skipping this step can result in a memory leak
for param in params:
    param.requires_grad_(False)

train_metric = Accuracy().to(config.device)
val_metric = Accuracy().to(config.device)
It is important to call param.requires_grad_(False) (or to detach the gradients in the optimizer) so that the parameters don't accumulate a backpropagation graph and keep adding to it instead of releasing that memory.
💡

🎓 Compute Loss and Gradients in a Stateless Manner

In contrast to vanilla PyTorch, we will create a function to compute the loss in a stateless manner, given a batch of inputs and the respective labels. It's important that this function accepts the parameters, the input batch, and the label batch, since we will be transforming over them.
import torch.nn.functional as F

def compute_stateless_loss(params, inputs, labels):
    outputs = functional_model(params, inputs)
    loss = F.cross_entropy(outputs, labels)
    return loss
Now that we have a function for computing the loss in a stateless manner, we also need a function that can compute its gradient. We can use the functorch.grad operator to transform the loss function into a function that computes the gradients.
from functorch import grad, grad_and_value

# transformed functions
compute_gradients = grad(compute_stateless_loss)
compute_gradients_and_loss = grad_and_value(compute_stateless_loss)
Note that while functorch.grad transforms a function to return just the gradients, functorch.grad_and_value returns a function that computes a tuple of the gradients and the primal (forward) value.
💡
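As a quick sanity check, both transformed functions can be called directly on a dummy batch. This is only a sketch, assuming the CIFAR-10 input shape of 3x32x32 and the config defined earlier.

dummy_inputs = torch.randn(8, 3, 32, 32, device=config.device)
dummy_labels = torch.randint(0, 10, (8,), device=config.device)

# returns a tuple of gradients, one per tensor in `params`
gradients = compute_gradients(params, dummy_inputs, dummy_labels)

# returns (gradients, loss): the gradients plus the forward (primal) loss value
gradients, loss = compute_gradients_and_loss(params, dummy_inputs, dummy_labels)
print(len(gradients), loss.item())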

👾 Functional Optimizer

One of the advantages of FuncTorch is that, since the parameters are completely decoupled from the model, we can create optimizers in a stateless manner as well: an optimizer becomes simply a function of the weights and gradients.
Note that PyTorch is missing a "functional optimizer API" (possibly coming soon?) so we're going to naively re-implement Stochastic Gradient Descent as a function.
# naive SGD optimizer implementation
def sgd_optimizer(weights, gradients, learning_rate):
    def step(weight, gradient):
        return weight - learning_rate * gradient

    return [step(weight, gradient) for weight, gradient in zip(weights, gradients)]
There's a feature request on the functorch repository to extend make_functional() or create make_functional_optimizer() to support torch.optim.Optimizer. You can find it at https://github.com/pytorch/functorch/issues/372
💡


🏋️ Training Loop

Now, let's combine everything to put together a training loop:
from tqdm.auto import tqdm

# functional train step
def functional_step(params, inputs, labels):
    # --------------------------------------------------------------------
    # compute gradients and loss using the transformed function instead of
    # outputs = model(inputs)
    # loss = criterion(outputs, labels)
    # loss.backward()
    gradients, loss = compute_gradients_and_loss(params, inputs, labels)
    # --------------------------------------------------------------------
    # update the parameters using the functional optimizer
    # instead of `optimizer.step()`
    params = sgd_optimizer(params, gradients, config.learning_rate)
    # --------------------------------------------------------------------
    return params, loss

# training function
def train(functional_step, params, metric):
    running_loss = 0.
    metric.reset()
    # loop over the train dataloader
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader), leave=False):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = to_device(data, config.device)
        # perform functional train step
        params, loss = functional_step(params, inputs, labels)
        running_loss += loss.item()
        # forward propagation for updating the metric
        outputs = functional_model(params, inputs)
        metric(outputs, labels)
    return params, running_loss, metric

# validation function
def validate(params, metric):
    running_loss = 0.
    metric.reset()
    # loop over the validation dataloader
    for i, data in tqdm(enumerate(test_loader), total=len(test_loader), leave=False):
        inputs, labels = to_device(data, config.device)
        outputs = functional_model(params, inputs)
        loss = F.cross_entropy(outputs, labels)
        running_loss += loss.item()
        metric(outputs, labels)
    return running_loss, metric

# loop over the dataloaders multiple times
for epoch in tqdm(range(config.epochs)):
    # perform training (single loop over the train dataloader)
    params, train_loss, train_metric = train(functional_step, params, train_metric)
    # perform validation (single loop over the validation dataloader)
    val_loss, val_metric = validate(params, val_metric)
    print(
        f'[{epoch}] train_loss: {train_loss / 2000:.3f}, '
        f'train_accuracy: {train_metric.compute():.3f}, '
        f'val_loss: {val_loss / 2000:.3f}, '
        f'val_accuracy: {val_metric.compute():.3f}'
    )
    wandb.log({
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_accuracy": train_metric.compute(),
        "val_accuracy": val_metric.compute()
    })




🚀 Comparing Vanilla PyTorch vs. FuncTorch

We can see from the following panels that while FuncTorch is comparable to vanilla PyTorch in terms of results, it comes with a few caveats with respect to performance:
  • The GPU memory allocation in FuncTorch is significantly higher compared to vanilla PyTorch.
  • The overall GPU utilization is roughly similar in both cases in spite of the higher memory usage for FuncTorch. The memory issue is fixed after calling param.requires_grad_(False) on the model parameters.

[W&B panels: Run set (7 runs) comparing GPU memory allocation and GPU utilization for the vanilla PyTorch and FuncTorch experiments]
If you notice significantly high GPU memory allocation in your FuncTorch code, it might be that the params in the model have requires_grad=True; they'll hold onto the backpropagation graph and keep adding to it instead of releasing that memory. Hence we need to make sure that the params don't require gradients. You can refer to this issue on the functorch repository for more info: https://github.com/pytorch/functorch/issues/639
💡




🎬 Conclusion

In this report, we covered:
  • An overview of FuncTorch
  • Why composable function transforms are needed
  • How to decouple the parameters from the model
  • Stateless loss and gradient computation using the functorch.grad operator
  • How to implement a stateless optimizer to update the parameters
  • An image classification pipeline in vanilla PyTorch vs. FuncTorch, and a comparison of their performance



💛 Similar Reports


