
How To Create an Image Classification Model in JAX/Flax

In this article, we learn how to create a simple image classification model in Flax with a short tutorial complete with code and interactive visualizations.
In this article, we'll look at how you can create a simple image classification model in JAX/Flax. We'll use the CIFAR-10 dataset to keep the data side simple, and focus instead on exploring the API and understanding how to write training loops and instantiate models.
Flax is actively developed and has a growing community of researchers and engineers at Google who use it for their daily research. It's designed specifically for JAX and offers a lot of flexibility thanks to its lightweight module abstraction.
Here's what we'll be covering:

Table of Contents

  • Defining Model Architecture
  • The Loss Function
  • Evaluation Metrics
  • Propagation
  • Summary

Code

Defining Model Architecture

The core of Flax is module abstraction. Modules allow us to write parameterized functions just like writing a normal NumPy function with JAX. The Module API allows us to declare parameters and use them directly with the JAX APIs.
Here, we create a simple convolutional neural network (CNN) module for our classification model. A module is created by subclassing flax.linen.Module (imported as nn below) and implementing the __call__ method; with the @nn.compact decorator, the layers can be declared inline as the input flows through them.
Our custom CNN module has the following layers:
  • Two convolutional blocks, each consisting of a convolution followed by a ReLU non-linearity and average pooling.
  • We then flatten the output of the last convolutional block.
  • We then add a dense layer with 256 units followed by a ReLU activation.
  • CIFAR-10 has 10 classes, so the final dense layer has 10 units followed by a log-softmax activation.
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        # First convolutional block: convolution -> ReLU -> average pooling.
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Second convolutional block.
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten the feature maps into a single vector per example.
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        # CIFAR-10 has 10 classes; output log-probabilities over them.
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
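To sanity-check the module, we can initialize its parameters and run a single forward pass. The snippet below is a minimal sketch, not part of the original tutorial: the batch size and PRNG seed are arbitrary, and the dummy batch simply mimics the shape of CIFAR-10 images.

import jax
import jax.numpy as jnp

# Hypothetical smoke test: initialize parameters from a dummy CIFAR-10-shaped batch
# (8 RGB images of 32x32 pixels) and run one forward pass.
model = CNN()
rng = jax.random.PRNGKey(0)
dummy_batch = jnp.ones((8, 32, 32, 3))
params = model.init(rng, dummy_batch)['params']
logits = model.apply({'params': params}, dummy_batch)
print(logits.shape)  # (8, 10): one log-probability per class for each image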


The Loss Function

Here, we implement the cross-entropy loss function with two parameters, logits and label. Notice that, unlike in most other deep learning frameworks, this function is written to process a single sample rather than a whole batch, so we don't have to think about batch dimensions when writing custom loss functions. Since the model ends with a log-softmax, the per-sample loss is simply the negative log-probability assigned to the true class, -logits[label].
The @jax.vmap decorator is a built-in JAX transformation that automatically vectorizes our loss function so it can also process batches.
import jax

@jax.vmap
def cross_entropy_loss(logits, label):
    # Negative log-probability of the true class (the logits are log-probabilities).
    return -logits[label]
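As a quick illustration of what the vectorized loss does, here is a small hypothetical example (the numbers are made up for demonstration). Because the model ends with log_softmax, the logits passed in are log-probabilities.

import jax.numpy as jnp

# Three samples, four classes; labels are integer class indices.
logits = jnp.log(jnp.array([[0.7, 0.1, 0.1, 0.1],
                            [0.1, 0.8, 0.05, 0.05],
                            [0.25, 0.25, 0.25, 0.25]]))
labels = jnp.array([0, 1, 3])
# @jax.vmap maps the single-sample loss over the leading batch dimension.
per_example_loss = cross_entropy_loss(logits, labels)
print(per_example_loss)            # negative log-likelihood of each sample
print(jnp.mean(per_example_loss))  # the batch loss used during training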

Evaluation Metrics

Here, we create a simple compute_metrics function that averages the per-sample cross-entropy loss over the batch and computes the accuracy of the model's predictions.
import jax.numpy as jnp

def compute_metrics(logits, labels):
    # Average the per-sample losses over the batch.
    loss = jnp.mean(cross_entropy_loss(logits, labels))
    # A prediction is correct when the highest-scoring class matches the label.
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}
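A short hypothetical usage example (values made up) showing what the function returns:

# Two samples, three classes; both predictions are correct, so accuracy is 1.0.
logits = jnp.log(jnp.array([[0.9, 0.05, 0.05],
                            [0.1, 0.8, 0.1]]))
labels = jnp.array([0, 1])
print(compute_metrics(logits, labels))  # {'loss': ..., 'accuracy': 1.0}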

Propagation

Here, we implement a single training step, which is essentially the propagation part of our model:
  • We run the batch of images (batch['image']) through the model.
  • Calculate the loss by passing the output through the cross_entropy_loss function.
  • Calculate the gradient of the loss with respect to the parameters using the jax.grad function.
  • "Apply" this gradient to our optimizer with its apply_gradient method (the optimizer holds the model parameters as optimizer.target).
The @jax.jit decorator just-in-time compiles our function into fused device operations, which can then run efficiently on GPUs or TPUs.
@jax.jit
def train_step(optimizer, batch):
    def loss_fn(params):
        # Forward pass over the whole batch, then average the per-sample losses.
        logits = CNN().apply({'params': params}, batch['image'])
        loss = jnp.mean(cross_entropy_loss(logits, batch['label']))
        return loss
    # Differentiate the loss with respect to the parameters held by the optimizer.
    grad = jax.grad(loss_fn)(optimizer.target)
    # Take one optimization step.
    optimizer = optimizer.apply_gradient(grad)
    return optimizer
Here, we create a simple eval() function that runs the model over the evaluation set and returns the metrics from the compute_metrics function.
@jax.jit
def eval(params, eval_ds):
    # Scale pixel values to [0, 1] and run the forward pass on the full evaluation set.
    logits = CNN().apply({'params': params}, eval_ds['image'] / 255.0)
    return compute_metrics(logits, eval_ds['label'])
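To show how these pieces fit together, here is a minimal training-loop sketch. It is an illustration rather than part of the original article: it assumes train_ds and test_ds are dictionaries of NumPy arrays with 'image' and 'label' keys (for example, as loaded from tensorflow_datasets), it uses the legacy flax.optim API that the apply_gradient call above comes from, and the hyperparameters are arbitrary.

import jax
import jax.numpy as jnp
import numpy as np
from flax import optim  # legacy flax.optim API, matching optimizer.apply_gradient above

def train_model(train_ds, test_ds, num_epochs=10, batch_size=128, learning_rate=0.1):
    # Initialize the model parameters and wrap them in a Momentum optimizer.
    rng = jax.random.PRNGKey(0)
    params = CNN().init(rng, jnp.ones((1, 32, 32, 3)))['params']
    optimizer = optim.Momentum(learning_rate=learning_rate, beta=0.9).create(params)

    num_train = train_ds['image'].shape[0]
    for epoch in range(num_epochs):
        # Shuffle the training set and iterate over mini-batches.
        perm = np.random.permutation(num_train)
        for i in range(0, num_train, batch_size):
            idx = perm[i:i + batch_size]
            batch = {'image': train_ds['image'][idx] / 255.0,
                     'label': train_ds['label'][idx]}
            optimizer = train_step(optimizer, batch)
        # Evaluate on the held-out set after each epoch.
        metrics = eval(optimizer.target, test_ds)
        print(f"epoch {epoch + 1}: loss={metrics['loss']:.4f}, "
              f"accuracy={metrics['accuracy']:.4f}")
    return optimizer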

Summary

In this article, you saw how to create a simple image classification model in JAX/Flax and how to write accelerator-agnostic training loops. To see the full suite of W&B features, please check out this short 5-minute guide.
If you want more reports covering the math and from-scratch code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental development topics like GPU Utilization and Saving Models.
