How To Create an Image Classification Model in JAX/Flax
In this article, we learn how to create a simple image classification model in Flax with a short tutorial complete with code and interactive visualizations.
In this article, we'll look at how you can create a simple image classification model in JAX/Flax. We'll use the CIFAR-10 dataset to keep the data side simple and focus instead on exploring the API and understanding how to write model definitions and training loops.
Flax is actively developed and has a growing community of researchers and engineers at Google who use it for their daily research. It is designed specifically for JAX and offers a great deal of flexibility thanks to its Module abstraction.
Here's what we'll be covering:
Table of Contents
Code
Defining Model Architecture
The Loss Function
Evaluation Metrics
Propagation
Summary
Recommended Reading
Code
Defining Model Architecture
The core of Flax is the Module abstraction. Modules allow us to write parameterized functions just as we would write normal NumPy functions with JAX, and the Module API lets us declare parameters and use them directly with the JAX APIs.
Here, we create a simple convolutional neural network (CNN) module for our classification model. A Module is created by defining a subclass of flax.nn.Module and implementing the apply method.
Our custom CNN module has the following layers:
- Two convolutional blocks followed by ReLU non-linearity and average pooling.
- We then flatten the output of the last convolutional block.
- We then add a dense layer with 256 units followed by a ReLU activation.
- Since CIFAR-10 has 10 classes, the final dense layer has 10 units followed by a log-softmax activation function.
from flax import nn  # pre-Linen Flax API: modules subclass nn.Module and define an apply method

class CNN(nn.Module):
    def apply(self, x):
        # First convolutional block: conv -> ReLU -> average pooling.
        x = nn.Conv(x, features=32, kernel_size=(3, 3))
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Second convolutional block.
        x = nn.Conv(x, features=64, kernel_size=(3, 3))
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the batch dimension.
        x = x.reshape((x.shape[0], -1))
        # Dense head: 256 hidden units, then 10 class scores as log-probabilities.
        x = nn.Dense(x, features=256)
        x = nn.relu(x)
        x = nn.Dense(x, features=10)
        x = nn.log_softmax(x)
        return x
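The later snippets assume that a model and an optimizer already exist. As a minimal sketch (assuming the same pre-Linen flax.nn / flax.optim API as above and CIFAR-10's 32×32×3 inputs), they could be created as follows; the create_model and create_optimizer helpers and the momentum hyperparameters are illustrative choices rather than something fixed by this article:

import jax
import jax.numpy as jnp
from flax import nn, optim

def create_model(rng):
    # Trace the CNN on a dummy CIFAR-10-shaped batch to initialize its parameters.
    _, initial_params = CNN.init_by_shape(rng, [((1, 32, 32, 3), jnp.float32)])
    return nn.Model(CNN, initial_params)

def create_optimizer(model, learning_rate=0.1, beta=0.9):
    # Wrap the model in a momentum optimizer from flax.optim.
    optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
    return optimizer_def.create(model)

rng = jax.random.PRNGKey(0)
model = create_model(rng)
optimizer = create_optimizer(model)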
The Loss Function
Here, we implement the cross-entropy loss function with two parameters, logits and label. Notice that, unlike in most other deep learning frameworks, this function is written for a single sample rather than a whole batch. This way, we don't have to think much about dimensionality while writing custom loss functions.
The @jax.vmap decorator is a built-in JAX transformation that automatically vectorizes our loss function so it works on batches.
import jax

@jax.vmap
def cross_entropy_loss(logits, label):
    # logits holds the log-probabilities for a single sample; the loss is the
    # negative log-probability assigned to the true class.
    return -logits[label]
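To make the vectorization concrete, here is a small hypothetical check (the shapes and values are made up for illustration): the decorated function accepts a batch of logits and a batch of labels and returns one loss value per sample.

import jax.numpy as jnp

# A hypothetical batch of 2 samples over 10 classes, stored as log-probabilities.
batch_logits = jnp.log(jnp.full((2, 10), 0.1))
batch_labels = jnp.array([3, 7])

per_sample_loss = cross_entropy_loss(batch_logits, batch_labels)
print(per_sample_loss.shape)  # (2,) -- one loss per sample, each equal to -ln(0.1) ~= 2.30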
Evaluation Metrics
Here, we create a simple compute_metrics function that takes the mean of the per-sample cross-entropy losses and calculates the accuracy of our model.
import jax.numpy as jnp

def compute_metrics(logits, labels):
    # Mean negative log-likelihood over the batch.
    loss = jnp.mean(cross_entropy_loss(logits, labels))
    # Fraction of samples whose highest-scoring class matches the label.
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}
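As a quick sanity check with made-up values (not from the article), uniform predictions over the 10 classes should give a loss of about ln(10) ≈ 2.30:

dummy_logits = jnp.log(jnp.full((4, 10), 0.1))  # uniform log-probabilities over 10 classes
dummy_labels = jnp.array([0, 1, 2, 3])
print(compute_metrics(dummy_logits, dummy_labels))
# loss ~= 2.30; accuracy = 0.25, since argmax of each uniform row is class 0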
Propagation
Here, we put together the training step, i.e., the forward and backward propagation of our model:
- We run a batch of images through the model by passing batch['image'].
- We calculate the loss by passing the resulting logits through the cross_entropy_loss function.
- We calculate the gradient using the jax.grad function.
- We apply this gradient to our model via the optimizer's apply_gradient function.
The @jax.jit decorator compiles our function into fused device operations, which can then run efficiently on GPUs or TPUs.
@jax.jit
def train_step(optimizer, batch):
    def loss_fn(model):
        logits = model(batch['image'])
        loss = jnp.mean(cross_entropy_loss(logits, batch['label']))
        return loss
    # Differentiate the loss with respect to the model held by the optimizer.
    grad = jax.grad(loss_fn)(optimizer.target)
    # Take one optimization step and return the updated optimizer (and model).
    optimizer = optimizer.apply_gradient(grad)
    return optimizer
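The article does not show the surrounding epoch loop, so here is a minimal sketch of one. It assumes train_ds is a dictionary of NumPy arrays with 'image' and 'label' keys (for example, loaded via tensorflow_datasets), and it scales pixel values to [0, 1] to match the eval function below; the shuffling and batching scheme is an illustrative choice:

def train_epoch(optimizer, train_ds, batch_size, rng):
    num_train = train_ds['image'].shape[0]
    # Shuffle the sample indices and drop the last incomplete batch.
    perms = jax.random.permutation(rng, num_train)
    perms = perms[: (num_train // batch_size) * batch_size]
    perms = perms.reshape((-1, batch_size))
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        batch = {'image': batch['image'] / 255.0, 'label': batch['label']}
        optimizer = train_step(optimizer, batch)
    return optimizer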
Here, we create a simple eval() function that runs the model on the normalized evaluation images and returns the metrics from compute_metrics.
@jax.jit
def eval(model, eval_ds):
    logits = model(eval_ds['image'] / 255.0)
    return compute_metrics(logits, eval_ds['label'])
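Putting the pieces together, a hypothetical driver loop could alternate training and evaluation like this; it reuses the create_model, create_optimizer, and train_epoch sketches from above and assumes train_ds and test_ds dictionaries of CIFAR-10 arrays:

rng = jax.random.PRNGKey(0)
optimizer = create_optimizer(create_model(rng))

for epoch in range(10):
    rng, input_rng = jax.random.split(rng)
    optimizer = train_epoch(optimizer, train_ds, batch_size=128, rng=input_rng)
    # optimizer.target is the current model in flax.optim.
    metrics = eval(optimizer.target, test_ds)
    print(f"epoch {epoch}: loss={metrics['loss']:.4f}, accuracy={metrics['accuracy']:.4f}")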
Summary
In this article, you saw how to create a simple image classification model in JAX/Flax and how to write accelerator-agnostic training loops. To see the full suite of W&B features, please check out this short 5 minute guide.
If you want more reports covering the math and from-scratch code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental development topics like GPU Utilization and Saving Models.
Recommended Reading
Preventing The CUDA Out Of Memory Error In PyTorch
A short tutorial on how you can avoid the "RuntimeError: CUDA out of memory" error while using the PyTorch framework.
How to Initialize Weights in PyTorch
A short tutorial on how you can initialize weights in PyTorch with code and interactive visualizations.
How to Compare Keras Optimizers in Tensorflow for Deep Learning
A short tutorial outlining how to compare Keras optimizers for your deep learning pipelines in Tensorflow, with a Colab to help you follow along.
How To Use GPU with PyTorch
A short tutorial on using GPUs for your deep learning models with PyTorch, from checking availability to visualizing usage.
PyTorch Dropout for regularization - tutorial
Learn how to regularize your PyTorch model with Dropout, complete with a code tutorial and interactive visualizations.
How to save and load models in PyTorch
This article is a machine learning tutorial on how to save and load your models in PyTorch using Weights & Biases for version control.