Fine-tuning Gemma with JAX
Fine-tuning Gemma with JAX enables you to adapt a high-performance language model to your specific dataset using fast and scalable infrastructure. In this guide, we walk through the process of downloading the Gemma model from Kaggle, preparing a translation dataset, tokenizing the text, batching with padding and masking, and logging your dataset using Weights & Biases for reproducibility and reuse. We'll use OPUS Books (en-fr) as our translation dataset and Flax/JAX for training.
Fine-tuning Gemma models with JAX
The first thing we need to do is get access to the Gemma model weights, which are distributed on Kaggle. Go to kaggle.com, set up an account, and accept the Gemma use policy and license terms. Once you complete the consent form, you'll see instructions for using the model with various frameworks such as JAX, Keras, or PyTorch.
To use Gemma in a Colab notebook, we need to set some environment variables that tell Kaggle about our account so it can verify that we have accepted the terms. To do this, generate a new API token from your account page on Kaggle. This downloads a JSON file with two keys: username and key. We can then store these as Colab secrets, named KAGGLE_USERNAME and KAGGLE_KEY.
Similarly, generate an API token for your Weights & Biases account and add it as a secret; I named mine W&B. Putting it all together, we can now set the environment variables in our Colab notebook as follows:
```python
## Set environment variables
import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["WANDB_API_KEY"] = userdata.get('W&B')
```
NOTE: You might be prompted with a "Grant Access?" dialog box. Simply allow the notebook to access the variables.
Having set the environment variables, we can now download the Gemma weights using the kagglehub package:
```python
## Download Gemma Model from Kaggle Hub
import kagglehub

# Pick the variant you have access to, e.g. '2b' or '2b-it'
GEMMA_VARIANT = '2b'
GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
```
For this tutorial, we will use the OPUS Books dataset released by the language technology research group at the University of Helsinki (Helsinki-NLP). The dataset covers many language pairs; we'll use the "en-fr" subset to fine-tune an English-to-French translation model. The pre-processing workflow we follow here should work for any translation dataset.
```python
from datasets import load_dataset

hf_ds = load_dataset("opus_books", "en-fr", split="train")
hf_ds = hf_ds.with_format("tf", columns=["translation"], output_all_columns=True)
hf_ds = hf_ds.train_test_split(test_size=0.1)
```
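Each record stores its sentence pair in a nested `translation` dictionary keyed by language code, which is exactly what the tokenization step below indexes into. A quick way to confirm this for yourself:

```python
# Peek at one record; because of with_format("tf") above, the values
# come back as tf string tensors
sample = hf_ds["train"][0]
print(sample["translation"]["en"])  # English source sentence
print(sample["translation"]["fr"])  # French target sentence
```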
We will follow and adapt the pre-processing pipeline from the Gemma docs, starting with a custom tokenizer that adapts the raw Gemma SentencePiece tokenizer to our dataset.
```python
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf


class CustomTokenizer:
    def __init__(self, raw_tokenizer):
        self._spm_processor = raw_tokenizer

    @property
    def pad_id(self) -> int:
        return self._spm_processor.pad_id()

    def tokenize(self,
                 example: str | bytes,
                 prefix: str = '',
                 suffix: str = '',
                 add_eos: bool = True) -> jax.Array:
        # Normalize inputs to plain strings. When called through
        # tf.numpy_function, all arguments arrive as numpy bytes.
        if isinstance(example, bytes):
            example = example.decode('utf-8')
        elif isinstance(example, (tf.Tensor, np.ndarray)):
            example = example.numpy().decode('utf-8') if hasattr(example, 'numpy') else str(example)
        else:
            example = str(example)
        prefix = prefix.decode('utf-8') if isinstance(prefix, bytes) else str(prefix)
        suffix = suffix.decode('utf-8') if isinstance(suffix, bytes) else str(suffix)
        full_text = prefix + example + suffix

        # Create the token list, starting with the BOS token
        int_list = [self._spm_processor.bos_id()]
        # Add the tokens for the input string
        int_list.extend(self._spm_processor.EncodeAsIds(full_text))
        # If needed, add the EOS token
        if add_eos:
            int_list.append(self._spm_processor.eos_id())
        return jnp.array(int_list, dtype=jnp.int32)

    def tokenize_tf_op(self,
                       str_tensor: tf.Tensor,
                       prefix: str = '',
                       suffix: str = '',
                       add_eos: bool = True) -> tf.Tensor:
        # Wrap tokenize() so it can run inside the datasets map pipeline
        encoded = tf.numpy_function(
            self.tokenize,
            [str_tensor, prefix, suffix, add_eos],
            tf.int32,
        )
        encoded.set_shape([None])
        return encoded

    def to_string(self, tokens: jax.Array) -> str:
        # The SentencePiece method for converting IDs back to text is
        # DecodeIds (not EncodeIds)
        return self._spm_processor.DecodeIds(tokens.tolist())
```
A few things to note here:
- We add a common prefix to each source input that signals the translation task.
- We add a suffix to the end of each prompt that signals the model to start translating.
- Gemma models expect a "beginning of sequence" token at the start of each sequence, so we prepend the default BOS ID to every sample; the EOS ID is appended only when add_eos is set (i.e., for the target text, not the source prompt).
Now, let's load the raw tokenizer and adapt it to our dataset:
```python
import sentencepiece as spm

# Load the Gemma tokenizer
raw_tokenizer = spm.SentencePieceProcessor()
raw_tokenizer.Load(TOKENIZER_PATH)

# Create an instance of our custom tokenizer
tokenizer = CustomTokenizer(raw_tokenizer)

# Define wrapper functions
def tokenize_source(tokenizer, example: tf.Tensor):
    return tokenizer.tokenize_tf_op(
        example,
        prefix='Translate this into French:\n',
        suffix='\n',
        add_eos=False,
    )

def tokenize_destination(tokenizer, example: tf.Tensor):
    return tokenizer.tokenize_tf_op(example, add_eos=True)

# Tokenize the train and test splits
train_ds = hf_ds["train"].map(
    lambda x: {
        "src": tokenize_source(tokenizer, x["translation"]["en"]),
        "dst": tokenize_destination(tokenizer, x["translation"]["fr"]),
    },
    desc="Applying Tokenization to Training Dataset",
    remove_columns=["translation"],
)
test_ds = hf_ds["test"].map(
    lambda x: {
        "src": tokenize_source(tokenizer, x["translation"]["en"]),
        "dst": tokenize_destination(tokenizer, x["translation"]["fr"]),
    },
    desc="Applying Tokenization to Testing Dataset",
    remove_columns=["translation"],
)
```
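Before running the full map, it's worth sanity-checking the tokenizer on a single sentence. A minimal sketch (the exact token IDs depend on the Gemma vocabulary):

```python
# Tokenize one source sentence the same way the pipeline will
ids = tokenizer.tokenize(
    'Hello, world!',
    prefix='Translate this into French:\n',
    suffix='\n',
    add_eos=False,
)
print(ids)                       # int32 array starting with the BOS id
print(tokenizer.to_string(ids))  # round-trips back to the prompt text
```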
Having tokenized our dataset, we now need to create dataloaders. This step involves padding each sequence to a fixed length (so every batch has a consistent shape), filtering out samples that exceed the maximum sequence length, and shuffling and batching the result.
```python
from typing import TypedDict

# Note: TrainingInput isn't defined elsewhere in this guide; a TypedDict is
# one simple stand-in that keeps the map() output a plain dict, which is
# what Hugging Face datasets expects.
class TrainingInput(TypedDict):
    input_tokens: tf.Tensor  # Tokens fed to the model
    target_mask: tf.Tensor   # True where tokens contribute to the loss


class DatasetBuilder:
    def __init__(self,
                 tokenizer: CustomTokenizer,
                 max_seq_len: int):
        self._tokenizer = tokenizer
        self._base_data = {
            "train": train_ds,
            "test": test_ds,
        }
        self._max_seq_len = max_seq_len

    def _pad_up_to_max_len(self,
                           input_tensor: tf.Tensor,
                           pad_value: int | bool,
                           ) -> tf.Tensor:
        seq_len = tf.shape(input_tensor)[0]
        to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
        return tf.pad(
            input_tensor,
            [[0, to_pad]],
            mode='CONSTANT',
            constant_values=pad_value,
        )

    def _to_training_input(self, example) -> TrainingInput:
        # Extract source and destination tokens from the dictionary
        src_tokens = example['src']
        dst_tokens = example['dst']

        # The input sequence fed to the model is simply the concatenation
        # of the source and the destination.
        tokens = tf.concat([src_tokens, dst_tokens], axis=0)

        # To prevent the model from updating based on the source (input)
        # tokens, add a target mask to each input.
        q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
        a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
        mask = tf.concat([q_mask, a_mask], axis=0)

        # If the output token sequence is shorter than the target sequence
        # size, pad it with pad tokens.
        tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

        # We don't want to perform the backward pass on the pad tokens.
        mask = self._pad_up_to_max_len(mask, False)

        return TrainingInput(input_tokens=tokens, target_mask=mask)

    def get_train_dataset(self, batch_size: int):
        ds = self._base_data["train"]

        # Process samples
        ds = ds.map(
            self._to_training_input,
            desc="Process samples",
            remove_columns=["src", "dst"],
        )

        # Remove the samples that are too long.
        ds = ds.filter(
            lambda x: tf.shape(x["input_tokens"])[0] <= self._max_seq_len,
            desc="Removing Samples that are too long",
        )

        # Shuffle the dataset.
        ds = ds.shuffle(seed=42)

        # Build batches.
        ds = ds.batch(
            batch_size,
            drop_last_batch=True,
        )
        return ds

    def get_validation_dataset(self, batch_size: int):
        ds = self._base_data["test"]

        # Process samples
        ds = ds.map(
            self._to_training_input,
            desc="Process samples",
            remove_columns=["src", "dst"],
        )

        # Remove the samples that are too long.
        ds = ds.filter(
            lambda x: tf.shape(x["input_tokens"])[0] <= self._max_seq_len,
            desc="Removing Samples that are too long",
        )

        # Build batches.
        ds = ds.batch(
            batch_size,
            drop_last_batch=True,
        )
        return ds
```
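To make the masking behaviour concrete, here's a small, hypothetical walkthrough of `_to_training_input` on hand-picked token IDs (with Gemma's vocabulary the pad ID is 0):

```python
# Toy example: three "source" tokens and two "destination" tokens,
# padded up to a maximum sequence length of 8
builder = DatasetBuilder(tokenizer, max_seq_len=8)
toy = builder._to_training_input({
    'src': tf.constant([5, 6, 7], dtype=tf.int32),
    'dst': tf.constant([8, 9], dtype=tf.int32),
})
print(toy['input_tokens'])  # [5 6 7 8 9 0 0 0] -- dst appended, then padding
print(toy['target_mask'])   # [False False False True True False False False]
```

Only the destination positions are True in the mask, so the loss is computed exclusively on the French tokens, never on the source prompt or the padding.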
This can be a time-consuming process, and we shouldn't spend precious GPU compute on data handling. It therefore makes sense to perform all of this pre-processing once, save the dataset, switch our runtime to an accelerator, and simply download and reuse the processed dataset.
For this, we will use Weights & Biases Artifacts, which can track everything from datasets to models and evaluations. After generating our datasets, let's serialize them to disk and upload them:
```python
import wandb

# Batch sizes are placeholders -- tune them to your accelerator's memory
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32

# Process the raw dataset
dataset_builder = DatasetBuilder(tokenizer, max_seq_len=20)
train_ds = dataset_builder.get_train_dataset(TRAIN_BATCH_SIZE)
valid_ds = dataset_builder.get_validation_dataset(VALID_BATCH_SIZE)

# Serialize the processed datasets
train_ds.save_to_disk("artifacts/train_ds/")
valid_ds.save_to_disk("artifacts/valid_ds/")

# Start a wandb run
run = wandb.init(project="gemma", job_type="upload")

# Create artifacts for the training and validation splits
train_ds_artifact = wandb.Artifact(name="opus_en_fr_train", type="dataset")
valid_ds_artifact = wandb.Artifact(name="opus_en_fr_valid", type="dataset")

# Add the local directories to the artifacts
train_ds_artifact.add_dir("artifacts/train_ds/")
valid_ds_artifact.add_dir("artifacts/valid_ds/")

# Upload to wandb
run.log_artifact(train_ds_artifact)
run.log_artifact(valid_ds_artifact)
run.finish()
```
These datasets are now available for any future training run; we can download and load them as follows:
```python
# Download the pre-processed datasets
api = wandb.Api()
train_ds_artifact = api.artifact("sauravmaheshkar/gemma/opus_en_fr_train:v0")
valid_ds_artifact = api.artifact("sauravmaheshkar/gemma/opus_en_fr_valid:v0")
valid_ds_dir = valid_ds_artifact.download()
train_ds_dir = train_ds_artifact.download()

# Load from disk
from datasets import load_from_disk

train_ds = load_from_disk(train_ds_dir)
valid_ds = load_from_disk(valid_ds_dir)

# Load as JAX arrays
train_ds = train_ds.with_format("jax")
valid_ds = valid_ds.with_format("jax")
```
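As a quick sanity check that the artifacts round-tripped correctly, you can inspect the first batch; the shapes below assume the batch sizes and `max_seq_len=20` used earlier:

```python
# Each row of the batched dataset is a full batch of shape
# (batch_size, max_seq_len) for both the tokens and the mask
first_batch = train_ds[0]
print(first_batch["input_tokens"].shape)  # e.g. (32, 20)
print(first_batch["target_mask"].shape)   # e.g. (32, 20)
```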
Having processed our dataset, let's now look at the training code.
Conclusion
Recommended Reading