Writing a Training Loop in JAX and Flax
In this article, we explore an end-to-end training and evaluation pipeline in JAX, Flax, and Optax for image classification, using W&B to track experiments.

In this article, we'll attempt to create a simple training and evaluation loop for a baseline image classification task using JAX, Flax, and Optax in an end-to-end manner.
We also explore how the Flax-based training and evaluation pipeline differs from its counterpart written in existing popular frameworks such as TensorFlow and PyTorch.
Lastly, we'll demonstrate how to take advantage of experiment-tracking using Weights & Biases for a Flax-based pipeline. Now, let's take a look at the individual frameworks before jumping into the code. We'll start with JAX.
Table of Contents
What Is JAX?
What Is Flax?
What Is Optax?
The Dataset
It's Time for Modeling
Anatomy of a Training Loop
Time for an Exercise...
Just In Time Compilation in JAX
Conclusion
What Is JAX?

JAX is an accelerated computation framework that brings together Autograd and XLA for high-performance machine learning research. It provides a simple NumPy and SciPy-like interface for fast scientific computing and machine learning which can be compiled to run on CPUs, GPUs, and TPUs. JAX also provides additional APIs for special accelerator ops when needed.
JAX supports the just-in-time (JIT) compilation of Python functions into XLA-optimized kernels using a one-function API. Thanks to its functional programming paradigm, JAX lets us apply composable transformations that take a function and return a new, transformed function without modifying the original, including but not limited to:
- Parallelization: Automatically parallelize code across multiple accelerators (including across hosts, e.g., for TPU pods) using jax.pmap
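To make the idea of composable transforms concrete, here is a minimal sketch (not from the original article) that composes jax.grad, jax.jit, and jax.vmap on a toy function; the predict and loss functions below are made up for illustration:

import jax
import jax.numpy as jnp

def predict(w, x):
    # A toy linear model followed by a nonlinearity.
    return jnp.tanh(jnp.dot(x, w))

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

# Compose transforms: differentiate w.r.t. `w`, then JIT-compile the result.
grad_fn = jax.jit(jax.grad(loss))

# Vectorize the per-example prediction over a batch without rewriting it.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.ones((3,))
x = jnp.ones((8, 3))
y = jnp.zeros((8,))
print(grad_fn(w, x, y).shape)        # (3,)
print(batched_predict(w, x).shape)   # (8,)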
What Is Flax?

Flax is a high-performance neural network library for JAX that is designed for flexibility. The design philosophy of Flax is best explained by the following quote from the official docs:
Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.
Flax comes with a bunch of higher-level APIs that make it easy for us to create training pipelines in a flexible manner, including but not limited to:
- flax.linen, the neural network API that enables us to easily define neural network models in a flexible and pythonic manner (discussed below).
- Serialization utilities for all Flax classes that carry a state (such as the models and optimizers); a small sketch follows this list.
- Training utilities for training pipelines, such as checkpoint handling, train state, and learning-rate schedules.
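As a quick illustration of the serialization utilities mentioned above, here is a minimal sketch, assuming you already have a params pytree (the dummy params below are made up for illustration):

import jax.numpy as jnp
from flax import serialization

# Assume `params` is any pytree of parameters, e.g. the output of model.init(...).
params = {'dense': {'kernel': jnp.ones((3, 4)), 'bias': jnp.zeros((4,))}}

# Serialize the parameters to bytes (e.g. to write them to disk).
param_bytes = serialization.to_bytes(params)

# Restore them later; a target with the same structure provides the "shape".
restored = serialization.from_bytes(params, param_bytes)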
Why Should You Care?
But why Flax, you ask? Don't we have TensorFlow and PyTorch already? Don't we have enough? Well, let us try to convince you why you should try the (JAX + Flax) ecosystem.
Most of us agree that this next era of deep learning will be about scaling. Researchers and industrial labs will continue to push the boundaries of current hardware and scale up. Models keep getting better, they need ever more data, and they are harder to train stably. None of this can happen on a single accelerator, so we need to look at multi-accelerator systems.
Wouldn't it be nice if you could write your code in an accelerator-agnostic way? This isn't possible with most current systems. In PyTorch, you need to split your model, move your tensors onto multiple devices, and then somehow figure out how to sync metrics, log concurrently, and so on. TensorFlow makes it a bit easier with its Strategy submodule but makes fine-grained modifications harder without significant extra code.
Enter (JAX + Flax). With Flax, you write your forward pass as a function for a single machine, convert that function using a JAX transformation, and run it on any number of devices. With considerably fewer modifications, you can run your code on any setup (CPU, GPU, TPU, etc.).
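As a rough sketch of that idea (not code from this article), jax.pmap turns a function written for a single device into one that runs the same computation on every available device; the step function below is made up for illustration:

import jax
import jax.numpy as jnp

def step(x):
    # Some per-device computation, written as if for a single machine.
    return jnp.sin(x) ** 2

n_devices = jax.local_device_count()

# Shard a batch so that each device gets one slice along the leading axis.
batch = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

# pmap replicates `step` across devices and runs it on each shard in parallel.
parallel_step = jax.pmap(step)
out = parallel_step(batch)
print(out.shape)  # (n_devices, 4)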
What Is Optax?
Optax is a gradient processing and optimization library for JAX. It was designed by DeepMind to facilitate research by providing building blocks that can easily be recombined in custom ways. Optax focuses on simple, well-tested, and efficient implementations of small composable building blocks (such as optimizers and loss functions) that can be combined into custom solutions. In this article, we'll primarily use Optax for our optimization algorithm.
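Here is a minimal sketch of the typical Optax workflow, i.e., creating a gradient transformation, initializing its state, and applying updates; the toy parameters and loss function below are assumptions for illustration, not the setup used later in this article:

import jax
import jax.numpy as jnp
import optax

# A toy set of parameters and a toy loss function.
params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}

def loss_fn(params, x, y):
    pred = jnp.dot(x, params['w']) + params['b']
    return jnp.mean((pred - y) ** 2)

# Create the optimizer and initialize its state for our parameters.
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# One optimization step: compute grads, transform them, apply the updates.
x, y = jnp.ones((8, 3)), jnp.zeros((8,))
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)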
This article was written as a Weights & Biases Report, which is a project management and collaboration tool for machine learning projects. Reports let you organize and embed visualizations, describe your findings, share updates with collaborators, and more. To learn more about reports, check out Collaborative Reports.
The Dataset
Neither JAX nor Flax has a native API for building data loading pipelines yet. You can use either the torch.utils.data API and Torchvision datasets from PyTorch, or the tf.data API and TensorFlow Datasets from TensorFlow, to build an input pipeline.
However, most JAX practitioners prefer the tf.data API for building data loading pipelines for JAX and Flax-based machine learning workflows. In this article, we'll build a simple data loading pipeline for the CIFAR-10 dataset using TensorFlow Datasets.
You may refer to this GitHub issue, which discusses the preferred way of constructing data loading pipelines for machine learning workflows in JAX.
TensorFlow Datasets is a collection of ready-to-use datasets for TensorFlow, JAX, and other Python ML frameworks. All datasets are exposed as tf.data.Datasets, enabling easy-to-use, high-performance input pipelines. TensorFlow Datasets also provides a large list of ready-to-use datasets. For a quick introduction to the general usage of TensorFlow Datasets, refer to the official quick start guide.
Setting Up a Weights & Biases Run
Let's call wandb.init to initialize a new job. This creates a new run in Weights & Biases and launches a background process to sync data. We will also sync all the configs of our experiments with the W&B run, which makes it far easier for us to reproduce the results of the experiment later.
import wandb

# Initializing a Weights & Biases Run
wandb.init(
    project="simple-training-loop",
    entity="jax-series",
    job_type="simple-train-loop"
)

# Setting the configs of our experiment using `wandb.config`.
# This way, Weights & Biases automatically syncs the configs of
# our experiment, which can be used to reproduce the results of the experiment.
config = wandb.config
config.seed = 42
config.batch_size = 64
config.validation_split = 0.2
config.pooling = "avg"
config.learning_rate = 1e-4
config.epochs = 15

You can visit any run on your Weights & Biases project and refer to the synced configs.
A simple data loading pipeline in TensorFlow
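The embedded panel above isn't reproduced in this text. As a rough reference, here is a minimal sketch of what a CIFAR-10 input pipeline built with TensorFlow Datasets and tf.data might look like; the get_datasets helper, normalization, and batching choices below are assumptions rather than the article's exact code, and the validation split driven by config.validation_split is omitted for brevity.

import tensorflow as tf
import tensorflow_datasets as tfds

def get_datasets(batch_size=64):
    # Sketch: load CIFAR-10 as (image, label) pairs.
    train_ds = tfds.load("cifar10", split="train", as_supervised=True)
    test_ds = tfds.load("cifar10", split="test", as_supervised=True)

    def preprocess(image, label):
        # Scale pixel values to [0, 1]; JAX/Flax consume NHWC float arrays as-is.
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    train_ds = (
        train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(10_000)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )
    test_ds = (
        test_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )
    return train_ds, test_ds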
It's Time for Modeling
Initially introduced after v0.4.0, the Linen API makes it easy to build "modules" for various deep learning methods while maintaining and respecting the functional paradigm, and it provides excellent support for JAX transformations such as vmap, remat, or scan. Linen was created to let developers keep creating Python objects (such as dataclasses or object-oriented subclasses) while still following the functional, single-method approach.
If you come from the Keras or PyTorch ecosystem, the Linen API has easy-to-understand analogies:
- In PyTorch or Keras subclassed models, we define all submodules and layers in the __init__ method. Flax has a similar method called setup() that we override.
- Instead of a forward() method in PyTorch models or a call() method in Keras, Flax modules define a __call__ method.
Let's now define a very simple convolution-based neural network for image classification. Instead of some famous architecture, we'll create a simple custom architecture by subclassing linen.Module.
from typing import Callable

import flax.linen as nn


class CNN(nn.Module):
    pool_module: Callable = nn.avg_pool

    def setup(self):
        self.conv_1 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv_2 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv_3 = nn.Conv(features=64, kernel_size=(3, 3))
        self.conv_4 = nn.Conv(features=64, kernel_size=(3, 3))
        self.conv_5 = nn.Conv(features=128, kernel_size=(3, 3))
        self.conv_6 = nn.Conv(features=128, kernel_size=(3, 3))
        self.dense_1 = nn.Dense(features=1024)
        self.dense_2 = nn.Dense(features=512)
        self.dense_output = nn.Dense(features=10)

    @nn.compact
    def __call__(self, x):
        x = nn.relu(self.conv_1(x))
        x = nn.relu(self.conv_2(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(self.conv_3(x))
        x = nn.relu(self.conv_4(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(self.conv_5(x))
        x = nn.relu(self.conv_6(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(self.dense_1(x))
        x = nn.relu(self.dense_2(x))
        return self.dense_output(x)
Initializing the Module
Now that we have defined the CNN Module, we need to initialize it. However, unlike TensorFlow or PyTorch, the parameters of a Flax Module are not stored with the model itself. We need to initialize the parameters by calling the init function with a PRNG key and a dummy input of the same shape as the expected input:
rng = jax.random.PRNGKey(config.seed)                 # PRNG Key
x = jnp.ones(shape=(config.batch_size, 32, 32, 3))    # Dummy Input
model = CNN(pool_module=MODULE_DICT[config.pooling])  # Instantiate the Model
params = model.init(rng, x)                           # Initialize the parameters
jax.tree_map(lambda x: x.shape, params)               # Check the parameters
Output
Q) But Why Do We Need To Initialize the Module?
A) How nice of you to ask. As mentioned above, the JAX ecosystem is based on a functional paradigm with pure functions. Our training and evaluation steps are simply functions that we will periodically call. But how do we update the parameters if there is no global context? Well, simply, we introduce an intermediate variable that is transferred/passed at each function call.
That's what the init function does: it takes the module and returns the initialized variables. After we initialize the model, we'll use these variables to create a TrainState, a utility class for handling parameter and gradient updates. This is a key feature of the newer Flax versions: instead of initializing the model again and again with new variables, we update the "state" of the model and pass it as an input to our functions.
import optax
from flax.training import train_state


def init_train_state(
    model, random_key, shape, learning_rate
) -> train_state.TrainState:
    # Initialize the Model
    variables = model.init(random_key, jnp.ones(shape))
    # Create the optimizer
    optimizer = optax.adam(learning_rate)
    # Create a State
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=variables['params'],
    )


state = init_train_state(
    model, rng, (config.batch_size, 32, 32, 3), config.learning_rate
)
- apply_fn: Typically, the apply method of the Flax Module
- tx: Typically, an Optax Optimizer
- params: The parameters from the initialized variable dictionary (for an example of the Initialized Variable Dict, you can refer to the above output subsection).
Anatomy of a Training Loop
Before going through a training loop in JAX and Flax, let's quickly recap what a training loop looks like in PyTorch and TensorFlow:
A comparison between the anatomy of a typical training loop written in TensorFlow and PyTorch
A JAX-Based Train and Validation Loop, Explained
Now, let's take a look at a typical training and validation loop written using JAX and Flax:
def train_and_evaluate(train_dataset, eval_dataset, test_dataset, state, epochs):
    num_train_batches = tf.data.experimental.cardinality(train_dataset)
    num_eval_batches = tf.data.experimental.cardinality(eval_dataset)
    num_test_batches = tf.data.experimental.cardinality(test_dataset)

    for epoch in tqdm(range(1, epochs + 1)):
        best_eval_loss = 1e6

        # ============== Training ============== #
        train_batch_metrics = []
        train_datagen = iter(tfds.as_numpy(train_dataset))
        for batch_idx in range(num_train_batches):
            batch = next(train_datagen)
            state, metrics = train_step(state, batch)
            train_batch_metrics.append(metrics)
        train_batch_metrics = accumulate_metrics(train_batch_metrics)

        # ============== Validation ============= #
        eval_batch_metrics = []
        eval_datagen = iter(tfds.as_numpy(eval_dataset))
        for batch_idx in range(num_eval_batches):
            batch = next(eval_datagen)
            metrics = eval_step(state, batch)
            eval_batch_metrics.append(metrics)
        eval_batch_metrics = accumulate_metrics(eval_batch_metrics)

        # Log Metrics to Weights & Biases
        wandb.log({
            "Train Loss": train_batch_metrics['loss'],
            "Train Accuracy": train_batch_metrics['accuracy'],
            "Validation Loss": eval_batch_metrics['loss'],
            "Validation Accuracy": eval_batch_metrics['accuracy']
        }, step=epoch)

    return state
Now let's walk through the aforementioned training loop in detail and go through the training and evaluation steps:
Train Step
Let's see what an example train_step looks like:
@jax.jit
def train_step(state: train_state.TrainState, batch: jnp.ndarray):
    image, label = batch

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, image)
        loss = cross_entropy_loss(logits=logits, labels=label)
        return loss, logits

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = gradient_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=label)
    return state, metrics
- Any train step takes two basic parameters: the state and the batch (or whatever form the input takes).
- As a best practice, we define the loss function inside this function. We get the logits from the model using the apply_fn from the TrainState (which is just the apply method of the model), passing in the parameters and the input. We then compute the loss from the logits and the labels and return both the loss and the logits (this is key).
- Next, we transform the function using the jax.value_and_grad() transformation. Unlike jax.grad(), which creates a function that returns only the derivative, jax.value_and_grad() returns a function that evaluates both the original function's value and its gradient. (Notice the has_aux parameter: we set it to True because the loss function returns the logits, an auxiliary output, alongside the loss.)
- We then calculate the gradients and obtain the logits by passing in the parameters of the state. Notice how the function returns both the value (loss and logits) and the gradients, because we used jax.value_and_grad() instead of jax.grad(); we'll need these logits to calculate metrics after the step.
- We then perform the parameter update (essentially backpropagation) by applying the calculated gradients to the TrainState with the .apply_gradients() method.
- Finally, we calculate the metrics using a utility compute_metrics function.
Utility function to compute metrics
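The collapsed panel with these utilities isn't reproduced in this text. As a reference, here is a minimal, hypothetical sketch of what the cross_entropy_loss, compute_metrics, and accumulate_metrics helpers used in the snippets above might look like, assuming integer class labels and 10 CIFAR-10 classes; it is not the article's exact code.

# Hypothetical sketch of the metric utilities referenced above, not the
# article's exact code. Assumes integer labels and 10 CIFAR-10 classes.
import jax
import jax.numpy as jnp
import numpy as np
import optax

def cross_entropy_loss(*, logits, labels):
    one_hot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()

def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}

def accumulate_metrics(metrics):
    # Average each metric over all batches of the epoch.
    metrics = jax.device_get(metrics)
    return {k: np.mean([m[k] for m in metrics]) for k in metrics[0]}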
Evaluation Step
Now that we've seen what a training step looks like, let's see what an eval_step looks like:
@jax.jit
def eval_step(state, batch):
    image, label = batch
    logits = state.apply_fn({'params': state.params}, image)
    return compute_metrics(logits=logits, labels=label)
Similar to our train_step, this function also takes the state and the batch. We simply perform a forward pass on the data, obtain the logits, and then compute the corresponding metrics. As this is the eval_step, we don't compute gradients or update the parameters of the TrainState.
At the end of each epoch, we accumulate the metrics from all the steps to compute the final metrics for the epoch, log them to Weights & Biases, and obtain beautiful-looking plots.
Weights & Biases enables us to visualize our metrics and other experiment results as beautiful interactive plots
Time for an Exercise...
Now that we know how to write a simple training and validation loop using JAX and Flax, let's get our hands dirty with a little exercise.
- Go ahead and visit the following notebook on Google Colab
- The notebook walks you through the process of training a simple image classification model in JAX and Flax.
- There are 4 tasks in the notebook you need to solve in order to train the model.
- The code is integrated with Weights & Biases' stellar experiment-tracking features. Once you complete all the tasks and run the notebook, the results of your experiment will be automatically logged to the leaderboard.
Just In Time Compilation in JAX
While writing the train_step and eval_step functions in the previous section, you must have noticed that we decorate these functions using jax.jit, and you might be wondering why we used this particular decorator... 🧐
The jax.jit() transform performs Just-In-Time (JIT) compilation of a JAX Python function so it can be executed efficiently by XLA. As discussed previously, JAX enables us to write our code in an accelerator-agnostic manner so that the same source code can be executed on any accelerator. What we want is to give the XLA compiler as much code as possible so it can fully optimize that code. This is the purpose served by applying the jax.jit transform to a JAX Python function: the function is compiled just in time by the XLA compiler.
The first time we run the just-in-time compiled function on the state and input batch, the Python source code is traced and converted into a simple intermediate representation called jaxpr. The jaxpr is then compiled using XLA into very efficient code optimized for our respective accelerator (GPU or TPU). This enables us to run the same code efficiently on any given accelerator.
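As a small aside (a sketch, not from the original pipeline), you can inspect the jaxpr that JAX traces for a function using jax.make_jaxpr; the toy squared_error function below is made up for illustration:

import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # A toy function: mean squared error of a linear model.
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

# Print the jaxpr traced from the Python function for given input shapes.
w = jnp.ones((3,))
x = jnp.ones((5, 3))
y = jnp.ones((5,))
print(jax.make_jaxpr(squared_error)(w, x, y))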
Note that instead of applying the jax.jit transform as a decorator, we can also apply it in the following manner:
def train_step(state: train_state.TrainState, batch: jnp.ndarray):
    ...

# Now this jitted version of the train_step function would be used in the training loop
jitted_train_step = jax.jit(train_step)
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra by Google that can accelerate TensorFlow models and JAX functions with potentially no source code changes.
Conclusion
- We discussed the purpose of these frameworks in our workflow and why we might use them instead of existing popular frameworks.
- We discussed how to create a simple data pipeline using TensorFlow Datasets and tf.data.Datasets for the CIFAR-10 dataset.
- We discussed how we can set up a Weights & Biases run and sync all the configs for our experiment to be referred to anytime for reproducibility.
- We discussed how to utilize the linen API to define a very simple convolution-based neural network for image classification.
- We discussed the general anatomy of a machine-learning training loop and compared how they are written in existing popular frameworks.
- We then discussed in detail how to write a training and validation loop for JAX and Flax for our image classification task and track the metrics using Weights & Biases.
- We got our hands dirty building a training and validation pipeline for an image classification model in JAX and Flax.
- We finally discussed the significance of just-in-time compilation in JAX and how it enables us to write accelerator-agnostic code.