
Fine-tuning Gemma with JAX

Created on October 25|Last edited on April 9
Fine-tuning Gemma with JAX enables you to adapt a high-performance language model to your specific dataset using fast and scalable infrastructure. In this guide, we walk through the process of downloading the Gemma model from Kaggle, preparing a translation dataset, tokenizing the text, batching with padding and masking, and logging your dataset using Weights & Biases for reproducibility and reuse. We'll use OPUS Books (en-fr) as our translation dataset and Flax/JAX for training.


The Gemma model family



Fine-tuning Gemma models with JAX

The first thing we need to do is get access to the Gemma model weights. The Flax/JAX checkpoints for this model family are distributed through Kaggle. Go to kaggle.com, set up an account, and accept the Gemma use policy and license terms. Once you have completed the consent form and accepted the terms and conditions, the model page shows instructions for using the model with various frameworks such as JAX, Keras, or PyTorch.
To use Gemma in Colab, we need to pass our Kaggle account information through environment variables so that Kaggle can verify that we have accepted the terms. To do this, generate a new API token from your account settings page on Kaggle. This downloads a JSON file (kaggle.json) with two keys: username and key. We can then store these values as Colab secrets named KAGGLE_USERNAME and KAGGLE_KEY.
Similarly, generate an API key for your Weights & Biases account and add it as a secret; I named mine W&B. Putting it all together, we can now set the environment variables in our Colab notebook as follows:
```python
## Set environment variables
import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["WANDB_API_KEY"] = userdata.get('W&B')
```
NOTE: You might be prompted with a "Grant access?" dialog box. Allow the notebook to access the secrets.
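Outside Colab there is no `userdata` secret store. A minimal sketch, assuming you saved the token Kaggle generated as `~/.kaggle/kaggle.json` (the path is an assumption; point it wherever you stored the file), copies its `username` and `key` fields into the same two environment variables. The helper name `export_kaggle_credentials` is hypothetical.

```python
import json
import os


def export_kaggle_credentials(token_path: str) -> None:
    """Copy the username/key pair from a Kaggle API token file into env vars.

    Hypothetical helper: it mirrors what the Colab secrets above provide.
    """
    with open(os.path.expanduser(token_path), encoding="utf-8") as f:
        token = json.load(f)  # expected shape: {"username": "...", "key": "..."}
    os.environ["KAGGLE_USERNAME"] = token["username"]
    os.environ["KAGGLE_KEY"] = token["key"]
```

Usage would be `export_kaggle_credentials("~/.kaggle/kaggle.json")` before calling kagglehub; the W&B key can be exported the same way from wherever you keep it.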
Having set the environment variables, we can download the Gemma weights with the kagglehub package:
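```python
## Download Gemma Model from Kaggle Hub
import os

import kagglehub

# Which checkpoint to fetch, e.g. '2b' or '2b-it'.
# '2b-it' is an example value (an assumption); pick the variant you need.
GEMMA_VARIANT = '2b-it'

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
```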
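Before building anything on top of these paths, it can help to fail early if the download did not produce the layout the code above assumes: a checkpoint directory named after the variant, with `tokenizer.model` next to it. The helper below, `resolve_gemma_paths`, is a hypothetical stdlib-only sketch; it does not call kagglehub and only inspects the filesystem.

```python
import os


def resolve_gemma_paths(gemma_path: str, variant: str) -> tuple[str, str]:
    """Return (checkpoint_dir, tokenizer_file) after checking that both exist.

    Assumes the layout used above: <gemma_path>/<variant>/ holds the Flax
    checkpoint and <gemma_path>/tokenizer.model holds the SentencePiece model.
    """
    ckpt_path = os.path.join(gemma_path, variant)
    tokenizer_path = os.path.join(gemma_path, "tokenizer.model")
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
    if not os.path.isfile(tokenizer_path):
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
    return ckpt_path, tokenizer_path
```

With the variables above, `CKPT_PATH, TOKENIZER_PATH = resolve_gemma_paths(GEMMA_PATH, GEMMA_VARIANT)` would replace the two `os.path.join` lines and raise a clear error if a file is missing.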
For this tutorial, we will use the OPUS Books dataset released by the Language Technology Research Group at the University of Helsinki (Helsinki-NLP). The dataset covers many language pairs; we'll use the "en-fr" subset to fine-tune an English-to-French translation model. The pre-processing workflow in this tutorial should carry over to any translation dataset that stores each example as a "translation" dictionary keyed by language code.
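```python
from datasets import load_dataset

hf_ds = load_dataset("opus_books", "en-fr", split="train")
# Return the "translation" column as TensorFlow tensors so the tf-based
# tokenization ops below can consume it directly.
hf_ds = hf_ds.with_format("tf", columns=["translation"], output_all_columns=True)
# Hold out 10% of the examples for validation.
hf_ds = hf_ds.train_test_split(test_size=0.1)
```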
We follow and adapt the pre-processing pipeline from the Gemma docs, starting with a custom tokenizer that wraps the raw Gemma tokenizer for our dataset.
```python
import jax
import numpy as np
import sentencepiece as spm
import tensorflow as tf


def _to_str(value) -> str:
    """Convert the str/bytes/tensor values that tf.numpy_function passes in to str."""
    if isinstance(value, tf.Tensor):
        value = value.numpy()
    if isinstance(value, np.ndarray):
        value = value.item()  # 0-d array -> Python scalar
    if isinstance(value, bytes):
        return value.decode('utf-8')
    return str(value)


class CustomTokenizer:
    def __init__(self, raw_tokenizer: spm.SentencePieceProcessor):
        self._spm_processor = raw_tokenizer

    @property
    def pad_id(self) -> int:
        return self._spm_processor.pad_id()

    def tokenize(self,
                 example: str | bytes,
                 prefix: str = '',
                 suffix: str = '',
                 add_eos: bool = True
                 ) -> np.ndarray:
        # Inside tf.numpy_function every argument, including prefix and suffix,
        # arrives as bytes or a NumPy value, so normalize all three to str.
        full_text = _to_str(prefix) + _to_str(example) + _to_str(suffix)

        # Create token list starting with BOS token
        int_list = [self._spm_processor.bos_id()]

        # Add tokens for input string
        int_list.extend(self._spm_processor.EncodeAsIds(full_text))

        # If needed add EOS token
        if add_eos:
            int_list.append(self._spm_processor.eos_id())

        # tf.numpy_function expects a NumPy array back
        return np.array(int_list, dtype=np.int32)

    def tokenize_tf_op(
        self,
        str_tensor: tf.Tensor,
        prefix: str = '',
        suffix: str = '',
        add_eos: bool = True
    ) -> tf.Tensor:
        # Wrap the Python tokenizer so it can run as a TensorFlow op
        encoded = tf.numpy_function(
            self.tokenize,
            [str_tensor, prefix, suffix, add_eos],
            tf.int32)
        # numpy_function drops static shape information; restore the rank
        encoded.set_shape([None])
        return encoded

    def to_string(self, tokens: jax.Array | np.ndarray) -> str:
        # Decode token IDs back to text
        return self._spm_processor.DecodeIds(np.asarray(tokens).tolist())
```
A few things to note here:
  • We add a common prefix to each source sentence that signals the translation task.
  • We add a suffix at the end of each prompt to signal that the model should start translating.
  • Gemma models expect a "beginning of sequence" (BOS) token at the start of each sequence, so tokenize always prepends the BOS ID; the "end of sequence" (EOS) ID is appended only when add_eos=True, which we use for the French targets.
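To make the resulting layout concrete without downloading the real tokenizer, the sketch below swaps SentencePiece for a hypothetical toy tokenizer that assigns one ID per word or newline. The special IDs (`PAD_ID = 0`, `EOS_ID = 1`, `BOS_ID = 2`) are assumptions chosen to resemble Gemma's SentencePiece model, and every ordinary-word ID is made up. It reproduces what `CustomTokenizer.tokenize` does with the prefix and suffix that `tokenize_source` uses.

```python
import re

# Assumed special-token IDs for this toy example only.
PAD_ID, EOS_ID, BOS_ID = 0, 1, 2


class ToyTokenizer:
    """Stand-in for SentencePiece: one ID per word or newline (illustration only)."""

    def __init__(self) -> None:
        self.vocab: dict[str, int] = {}

    def encode(self, text: str) -> list[int]:
        ids = []
        for piece in re.findall(r"\S+|\n", text):
            # Unseen pieces get the next free ID after the special tokens.
            ids.append(self.vocab.setdefault(piece, len(self.vocab) + 3))
        return ids


def tokenize(tok: ToyTokenizer, text: str, prefix: str = "", suffix: str = "",
             add_eos: bool = True) -> list[int]:
    """Same layout as CustomTokenizer.tokenize: BOS, prefix + text + suffix, optional EOS."""
    ids = [BOS_ID] + tok.encode(prefix + text + suffix)
    if add_eos:
        ids.append(EOS_ID)
    return ids


tok = ToyTokenizer()
src = tokenize(tok, "The cat sleeps.", prefix="Translate this into French:\n",
               suffix="\n", add_eos=False)
dst = tokenize(tok, "Le chat dort.", add_eos=True)
print(src)  # [2, 3, 4, 5, 6, 7, 8, 9, 10, 7]
print(dst)  # [2, 11, 12, 13, 1]
```

Note that the destination also starts with BOS, because tokenize always prepends it; once source and destination are concatenated, the sequence therefore carries a second BOS right where the French translation begins.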
Now, let's load the raw tokenizer and adapt it to our dataset:
```python
import sentencepiece as spm
import tensorflow as tf

# Load the Gemma tokenizer
raw_tokenizer = spm.SentencePieceProcessor()
raw_tokenizer.Load(TOKENIZER_PATH)

# Create an instance of our custom tokenizer
tokenizer = CustomTokenizer(raw_tokenizer)


# Define wrapper functions
def tokenize_source(tokenizer, example: tf.Tensor):
    # English source: task prefix + sentence + newline, without EOS
    return tokenizer.tokenize_tf_op(
        example,
        prefix='Translate this into French:\n',
        suffix='\n',
        add_eos=False
    )


def tokenize_destination(tokenizer, example: tf.Tensor):
    # French target: sentence followed by EOS
    return tokenizer.tokenize_tf_op(
        example,
        add_eos=True
    )


# Tokenize Train and Test splits
train_ds = hf_ds["train"].map(
    lambda x: {
        "src": tokenize_source(tokenizer, x["translation"]["en"]),
        "dst": tokenize_destination(tokenizer, x["translation"]["fr"])
    },
    desc="Applying Tokenization to Training Dataset",
    remove_columns=["translation"],
)

test_ds = hf_ds["test"].map(
    lambda x: {
        "src": tokenize_source(tokenizer, x["translation"]["en"]),
        "dst": tokenize_destination(tokenizer, x["translation"]["fr"])
    },
    desc="Applying Tokenization to Testing Dataset",
    remove_columns=["translation"],
)
```
Having tokenized our dataset, we now need to create dataloaders. This step pads every sample to a fixed length so that all samples in a batch have the same shape, filters out samples that are too long, and then shuffles and batches the samples.
```python
import tensorflow as tf


class DatasetBuilder:
    def __init__(self,
                 tokenizer: CustomTokenizer,
                 max_seq_len: int):
        self._tokenizer = tokenizer
        # Tokenized splits produced in the previous step
        self._base_data = {
            "train": train_ds,
            "test": test_ds,
        }
        self._max_seq_len = max_seq_len

    def _pad_up_to_max_len(
        self,
        input_tensor: tf.Tensor,
        pad_value: int | bool,
    ) -> tf.Tensor:
        seq_len = tf.shape(input_tensor)[0]
        to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
        return tf.pad(input_tensor,
                      [[0, to_pad]],
                      mode='CONSTANT',
                      constant_values=pad_value,
                      )

    def _to_training_input(self, example) -> dict[str, tf.Tensor]:
        # Extract source and destination tokens from the dictionary
        src_tokens = example['src']
        dst_tokens = example['dst']
        # The input sequence fed to the model is simply the concatenation of the
        # source and the destination.
        tokens = tf.concat([src_tokens, dst_tokens], axis=0)

        # To prevent the model from updating based on the source (input)
        # tokens, add a target mask to each input.
        q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
        a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
        mask = tf.concat([q_mask, a_mask], axis=0)

        # If the token sequence is shorter than max_seq_len,
        # pad it with pad tokens.
        tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

        # Don't want to perform the backward pass on the pad tokens.
        mask = self._pad_up_to_max_len(mask, False)

        # Dataset.map expects a dict mapping column names to values.
        return {"input_tokens": tokens, "target_mask": mask}

    def get_train_dataset(self, batch_size: int):
        ds = self._base_data["train"]
        # Process samples
        ds = ds.map(
            self._to_training_input,
            desc="Process samples",
            remove_columns=["src", "dst"],
        )

        # Remove the samples that are too long.
        ds = ds.filter(
            lambda x: tf.shape(x["input_tokens"])[0] <= self._max_seq_len,
            desc="Removing Samples that are too long",
        )

        # Shuffle the dataset.
        ds = ds.shuffle(seed=42)

        # Build batches.
        ds = ds.batch(
            batch_size,
            drop_last_batch=True,
        )
        return ds

    def get_validation_dataset(self, batch_size: int):
        ds = self._base_data["test"]
        # Process samples
        ds = ds.map(
            self._to_training_input,
            desc="Process samples",
            remove_columns=["src", "dst"],
        )

        # Remove the samples that are too long.
        ds = ds.filter(
            lambda x: tf.shape(x["input_tokens"])[0] <= self._max_seq_len,
            desc="Removing Samples that are too long",
        )

        # Build batches (no shuffling for validation).
        ds = ds.batch(
            batch_size,
            drop_last_batch=True,
        )
        return ds
```
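The tensor logic in `_to_training_input`, the length filter, and `drop_last_batch=True` can be checked in isolation. This NumPy sketch re-implements them on made-up token IDs with a tiny `max_seq_len` (pad ID 0 is an assumption, and the function names are hypothetical). Shuffling is left out so the output is deterministic.

```python
import numpy as np

PAD_ID = 0  # assumed pad ID, used only for this illustration


def to_training_input(src: np.ndarray, dst: np.ndarray, max_seq_len: int) -> dict:
    """NumPy mirror of DatasetBuilder._to_training_input (hypothetical helper)."""
    tokens = np.concatenate([src, dst])
    # False on source (prompt) positions, True on target positions.
    mask = np.concatenate([np.zeros(len(src), dtype=bool), np.ones(len(dst), dtype=bool)])
    # Pad short sequences up to max_seq_len; longer ones are left untouched.
    to_pad = max(max_seq_len - len(tokens), 0)
    tokens = np.pad(tokens, (0, to_pad), constant_values=PAD_ID)
    mask = np.pad(mask, (0, to_pad), constant_values=False)  # padding adds no loss
    return {"input_tokens": tokens, "target_mask": mask}


def build_batches(pairs: list[tuple[np.ndarray, np.ndarray]],
                  max_seq_len: int, batch_size: int) -> list[dict]:
    """Pad, drop over-long samples, and stack into full batches (no shuffling)."""
    samples = [to_training_input(s, d, max_seq_len) for s, d in pairs]
    # Over-long sequences were never padded down, so this check removes them.
    samples = [x for x in samples if len(x["input_tokens"]) <= max_seq_len]
    n_full = len(samples) // batch_size  # like drop_last_batch=True
    batches = []
    for i in range(n_full):
        chunk = samples[i * batch_size:(i + 1) * batch_size]
        batches.append({key: np.stack([x[key] for x in chunk])
                        for key in ("input_tokens", "target_mask")})
    return batches


example = to_training_input(np.array([2, 5, 6]), np.array([2, 7, 1]), max_seq_len=8)
print(example["input_tokens"].tolist())  # [2, 5, 6, 2, 7, 1, 0, 0]
print(example["target_mask"].tolist())   # [False, False, False, True, True, True, False, False]

pairs = [
    (np.array([2, 5, 6]), np.array([2, 7, 1])),              # 6 tokens: kept
    (np.array([2, 5, 6, 8, 9]), np.array([2, 7, 3, 4, 1])),  # 10 tokens: filtered out
    (np.array([2, 9]), np.array([2, 4, 1])),                 # 5 tokens: kept
    (np.array([2, 3]), np.array([2, 1])),                    # 4 tokens: kept, then dropped with the partial batch
]
batches = build_batches(pairs, max_seq_len=8, batch_size=2)
print(len(batches), batches[0]["input_tokens"].shape)  # 1 (2, 8)
```

Because padding happens before filtering, every surviving sample has exactly `max_seq_len` tokens, which is what lets the samples be stacked into fixed-shape batches.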
Pre-processing can be time-consuming, and we shouldn't spend precious GPU time on data handling. It therefore makes sense to run all of this pre-processing once on a CPU runtime, save the processed dataset, switch the runtime to an accelerator, and then simply download and use the processed dataset.
For this, we will use Weights & Biases Artifacts, which can track everything from datasets to models and evaluations. After generating our datasets, let's serialize them to disk and upload them.
```python
import wandb

# Batch sizes are example values (an assumption). They are baked into the
# saved datasets, so choose them for the accelerator you will train on.
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4

# Process Raw Dataset
dataset_builder = DatasetBuilder(tokenizer, max_seq_len=20)
train_ds = dataset_builder.get_train_dataset(TRAIN_BATCH_SIZE)
valid_ds = dataset_builder.get_validation_dataset(VALID_BATCH_SIZE)

# Serialize processed dataset
train_ds.save_to_disk("artifacts/train_ds/")
valid_ds.save_to_disk("artifacts/valid_ds/")

# Start a wandb run
run = wandb.init(project="gemma", job_type="upload")

# Generate Artifacts for training and validation
train_ds_artifact = wandb.Artifact(name="opus_en_fr_train", type="dataset")
valid_ds_artifact = wandb.Artifact(name="opus_en_fr_valid", type="dataset")

# Add local dir to artifact
train_ds_artifact.add_dir("artifacts/train_ds/")
valid_ds_artifact.add_dir("artifacts/valid_ds/")

# Upload to wandb
run.log_artifact(train_ds_artifact)
run.log_artifact(valid_ds_artifact)
run.finish()
```
The processed datasets can now be downloaded and used in training runs as follows:
```python
import wandb
from datasets import load_from_disk

# Download pre-processed dataset
api = wandb.Api()
train_ds_artifact = api.artifact("sauravmaheshkar/gemma/opus_en_fr_train:v0")
valid_ds_artifact = api.artifact("sauravmaheshkar/gemma/opus_en_fr_valid:v0")

valid_ds_dir = valid_ds_artifact.download()
train_ds_dir = train_ds_artifact.download()

# Load from the downloaded directories
train_ds = load_from_disk(train_ds_dir)
valid_ds = load_from_disk(valid_ds_dir)

# Load as jax arrays
train_ds = train_ds.with_format("jax")
valid_ds = valid_ds.with_format("jax")
```
Having processed our dataset, let's now look at the training code.
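Before the training loop, it helps to see how `target_mask` is meant to be consumed. The sketch below shows one common way to use such a mask, written in NumPy rather than JAX so it runs anywhere; it is an assumption, not necessarily the exact loss used in this report. Logits at position t predict token t + 1, so targets and mask are both shifted by one, and the negative log-likelihood is averaged only over positions where the mask is True.

```python
import numpy as np


def masked_next_token_loss(logits: np.ndarray, input_tokens: np.ndarray,
                           target_mask: np.ndarray) -> float:
    """Mean cross-entropy over target positions only.

    logits:       [batch, seq_len, vocab] model outputs
    input_tokens: [batch, seq_len] token IDs from the pipeline above
    target_mask:  [batch, seq_len] True where a token should contribute to the loss
    """
    # Position t predicts token t + 1: drop the last logit and the first target.
    logits = logits[:, :-1, :]
    targets = input_tokens[:, 1:]
    mask = target_mask[:, 1:].astype(logits.dtype)

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Log-probability assigned to each true next token.
    token_log_probs = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    # Average over unmasked positions only (guard against an all-False mask).
    return float(-(token_log_probs * mask).sum() / max(mask.sum(), 1.0))
```

With uniform logits over a vocabulary of size V the result is log V regardless of sequence length, and changing the logits at a masked-out position leaves the loss unchanged, which is exactly the behavior the source and padding masks are meant to guarantee.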

Conclusion