Training a Protein Language Model on Cloud TPU
End-to-end example with JAX, Flax, Optax, and Weights & Biases
Introduction
In this report, we'll demonstrate how to train a protein language model on a Google Cloud TPU using JAX, Flax, and Optax with Weights & Biases.
First, we'll give a brief introduction to proteins and language models. Training a language model from scratch requires decent compute, so we'll discuss the benefits of choosing Cloud TPU as your compute platform. Then we'll briefly introduce the JAX, Flax, and Optax frameworks, review our choices of tokenization scheme, model architecture, and hyperparameters, and finish with model evaluation.
What is a Protein Language Model?
A protein language model is a machine learning model that is specifically designed to predict masked or incomplete parts of protein sequences. This type of model is useful because it is able to learn the underlying patterns and structure of a protein "language" in order to make educated guesses about missing or masked data.
By learning about the underlying "language" of protein sequences, these models are able to extract useful information that can be applied to other tasks, such as predicting protein structure, function, stability, and other properties. In this way, protein language models can be fine-tuned to perform more advanced tasks, ultimately leading to better predictions of protein function and other downstream tasks.
Still confused? Let's review this step by step:
- Proteins are complex biological molecules composed of chains of amino acids. They play a vital role in living organisms and can serve a wide range of functions. Alpha-amino acids, a subset of amino acids, serve as the building blocks of proteins and can be thought of as the "vocabulary" of the protein language. There are 22 alpha-amino acids in the genetic code, which are often represented by single letters (e.g. G for glycine and A for alanine).
- The sequence of amino acids in a protein determines its 3D shape, which is crucial for its function. Predicting protein shape from sequence is a challenging and active area of research (see, for example, AlphaFold).
- Language models were originally developed in natural language processing (NLP) as models that are trained on large amounts of unlabeled text by masking certain words in sentences and then asking the model to predict the missing words. These models can then be fine-tuned for other tasks, such as sentiment classification, named entity recognition, or question answering.
The table below shows examples of proteins, their amino acid sequences, and 3D structure visualizations:
[W&B table (run toasty-sky-88): protein names, amino acid sequences, and 3D structure visualizations]
Now, the use of language modeling techniques for protein sequences is not a new concept - several models, such as ESM, ProGen, or ProteinBERT, have already been published. However, compared to the rapid pace of innovation and publication in natural language processing (NLP), there is still much room for exploration and experimentation in this field.
Applying language modeling approaches to protein sequences can be a powerful tool for learning their representations, as well as for fine-tuning models to perform more advanced tasks such as predicting protein structure, function, and stability. Despite the progress that has been made so far, there is still much potential for further development and refinement of these models.
Table of Contents
- Introduction
- What is a Protein Language Model?
- Table of Contents
- The Benefits of Using Cloud TPUs
- Intro to Flax, JAX and Optax
- Model Architecture
- Tokenization
- Hyperparameter Search
- Training
- Evaluation
- Conclusion
- Related Reading on Protein Research
The Benefits of Using Cloud TPUs
Cloud TPU (Tensor Processing Unit) is custom hardware developed by Google specifically for machine learning acceleration. It supports PyTorch, TensorFlow, and JAX. The Cloud TPU hardware and software stack, which is based on the XLA compiler, includes specific optimizations to deliver efficient, optimized performance. For example, Cloud TPU pod slices (i.e., multi-node Cloud TPU clusters) are equipped with high-bandwidth chip-to-chip links (Cloud TPU v4 provides 6 Tbps of per-host bandwidth) for better scaling.

Additionally, the XLA compiler allows users to express SPMD-based (Single Program, Multiple Data) model-parallel configurations for large models. With APIs such as pjit in JAX, users can easily leverage SPMD to express a range of model and data parallelisms with just a few lines of code. The XLA compiler also offers out-of-the-box optimizations, such as overlapping compute and communication, to further boost the scaling efficiency of Cloud TPU systems for large models.
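To make this concrete, here is a minimal pjit sketch for sharding a computation across TPU devices. It assumes a recent JAX release; the sharding argument names have changed between versions (older releases use in_axis_resources), so treat it as illustrative rather than exact:

```python
# Minimal pjit/SPMD sketch (illustrative; argument names vary across JAX versions).
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

# Arrange all available devices (e.g., TPU cores) into a 1D mesh along a "data" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def forward(x, w):
    return jnp.dot(x, w)

# Shard the batch dimension of x across the mesh, replicate the weights.
sharded_forward = pjit(
    forward,
    in_shardings=(P("data", None), P(None, None)),
    out_shardings=P("data", None),
)

with mesh:
    x = jnp.ones((32, 128))
    w = jnp.ones((128, 64))
    y = sharded_forward(x, w)  # XLA handles partitioning and any needed communication
```

The key point is that forward is written as if it ran on a single device; the partitioning annotations and the XLA compiler take care of distributing the work.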
Unlike other hardware accelerators, such as GPUs, the Cloud TPU software stack is based on a compiler rather than a library-based approach. This compiler-based approach enables high performance and utilization on the TPU for a variety of workloads.
Cloud TPUs have demonstrated efficiency at scale for large models, including transformer models (for vision, text, speech, or multimodal tasks), recommendation systems, generative models, and distributed reinforcement learning workloads. Examples of this efficiency include the Pathways Language Model and Compute-Optimal Large Language Models. These advantages are further demonstrated in the MLPerf 2.0 benchmarks. Google has also published a scaling performance benchmark for large language models in its Cloud TPU v4 announcement.
Intro to Flax, JAX and Optax
When selecting a framework for developing and training our model on TPUs, we have the choice between TensorFlow, PyTorch, and JAX. For this project, we have chosen JAX.
JAX is a library that supports high-performance numerical computing on accelerators. It allows for the composition of NumPy-like operations and automatically differentiates them with Autograd. It also vectorizes and parallelizes these operations, compiling them into XLA-optimized kernels, making it a great fit for machine learning research on TPUs.
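As a tiny, self-contained taste of that composability (not code from this project), here is how plain NumPy-style code picks up gradients, vectorization, and XLA compilation:

```python
# Compose NumPy-like code with autodiff, vectorization, and XLA compilation.
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

grad_fn = jax.jit(jax.grad(squared_error))              # gradient w.r.t. w, compiled with XLA
per_example = jax.vmap(squared_error, in_axes=(None, 0, 0))  # vectorize over a batch axis

w = jnp.zeros((3,))
x = jnp.ones((8, 3))
y = jnp.ones((8,))
print(grad_fn(w, x, y), per_example(w, x, y).shape)
```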
To make our work easier, we will use higher-level frameworks built on top of JAX - Flax and Optax. Flax is a neural network library and ecosystem, and we will use Flax modules and submodules to construct our model. Optax is a library for composable gradient transformations, which will help us perform parameter optimization within our training loop.
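The toy example below (a small MLP, not the RoBERTa model we train later) sketches how the three libraries divide the work: Flax defines the module, JAX computes gradients, and Optax produces and applies the parameter updates:

```python
# Toy Flax model plus an Optax optimizer and a single training step (illustrative).
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    hidden: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(1)(x)

model = MLP(hidden=64)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 16)))
tx = optax.adamw(learning_rate=1e-4)
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        return jnp.mean((model.apply(p, x) - y) ** 2)
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
```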
To speed up our progress, we will utilize starter code from the Flax/JAX community week. You can also learn more about JAX and Flax from the linked reports below.
Writing a Training Loop in JAX and Flax
In this article, we explore an end-to-end training and evaluation pipeline in JAX, Flax, and Optax for image classification, using W&B to track experiments.
How To Create an Image Classification Model in JAX/Flax
In this article, we learn how to create a simple image classification model in Flax with a short tutorial complete with code and interactive visualizations.
Model Architecture
The natural language processing (NLP) community has greatly benefited from a variety of transformer language model architectures, including BERT, RoBERTa, BART, T5, GPT, DeBERTa, and others. We would be interested in seeing research that investigates some of the design choices made in these NLP models in the context of protein language and tasks.
One area to explore is pretext tasks. In causal language modeling (GPT-2), the model is asked to predict the next token in a sequence without being able to see what comes next. Masked language modeling (BERT, RoBERTa; see illustration below) allows the model to see the entire sequence, but some tokens are masked and the model must predict the words (or amino acids) that were masked.
Replaced token detection (ELECTRA, DeBERTa) is a more efficient pre-training task where some tokens in the input are replaced with plausible alternatives, and the model is trained to predict whether each token in the input was replaced or not. In a sequence-to-sequence (BART, T5) task, the model is trained to reconstruct the original sequence given a corrupted input.

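To make the masked language modeling objective concrete, here is a hedged sketch of BERT-style masking. The 15% masking rate and the 80/10/10 replacement split come from the original BERT recipe, not from anything specific to this report:

```python
# BERT-style masking sketch: 15% of tokens are selected; of those, 80% become the
# mask token, 10% a random token, and 10% stay unchanged (values from the BERT paper).
import numpy as np

def mask_tokens(token_ids, mask_id, vocab_size, rng, mask_prob=0.15):
    token_ids = np.array(token_ids)
    labels = np.full_like(token_ids, -100)           # -100 = position ignored by the loss
    selected = rng.random(token_ids.shape) < mask_prob
    labels[selected] = token_ids[selected]

    replace = selected & (rng.random(token_ids.shape) < 0.8)
    random_tok = selected & ~replace & (rng.random(token_ids.shape) < 0.5)
    token_ids[replace] = mask_id
    token_ids[random_tok] = rng.integers(0, vocab_size, size=token_ids.shape)[random_tok]
    return token_ids, labels

rng = np.random.default_rng(0)
ids, labels = mask_tokens(np.arange(10, 50), mask_id=4, vocab_size=1000, rng=rng)
```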
Another interesting area of research is altering the attention mechanism. Traditional language model architectures often have a maximum length limitation (typically 512 tokens), which can make it difficult to model long sequences. Architectures such as BigBird and Longformer have been developed to address this limitation. The DeBERTa architecture also introduced disentangled attention, which treats word content and position separately - this approach may also be applicable to protein language models.
Tokenization
Tokenization is a crucial step in language modeling - it involves dividing sequences into tokens and representing them numerically (in NLP, this is frequently a word or part of a word). The default approach for protein language models is to treat each amino acid residue as a separate token (similar to character language models in natural language). However, it may be worth considering increasing the vocabulary by also treating frequently occurring n-grams as separate "words". This could allow us to more accurately represent sequence motifs.
You can experiment with different tokenization approaches using this notebook from our example repository.
In this article, we will use a standard masked language modeling architecture (RoBERTa) and the SentencePieceBPETokenizer to train a new tokenizer on protein sequences. We will use a vocabulary size of 1000, but it's worth testing different vocabulary sizes to see how it impacts downstream tasks.
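For reference, training such a tokenizer takes only a few lines with the Hugging Face tokenizers library. The sequences below are placeholders; in practice you would stream them from the training split:

```python
# Sketch of training a SentencePiece-BPE tokenizer on protein sequences.
from tokenizers import SentencePieceBPETokenizer

# Placeholder sequences; in practice, iterate over the full training split.
sequences = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
    "MSDNGPQNQRNAPRITFGGPSDSTGSNQNGERSGAR",
]

tokenizer = SentencePieceBPETokenizer()
tokenizer.train_from_iterator(
    sequences,
    vocab_size=1000,  # matches the vocabulary size used in this report
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
)
print(tokenizer.encode(sequences[0]).tokens)
```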
Hyperparameter Search
When it comes to choosing the hyperparameters for our training, there are a few different approaches we can take. One option is to schedule several training runs on a sample from our target dataset, experimenting with different values of the hyperparameter in question. This is what we did in this case, testing different values of the learning rate to see how it affected the convergence of the training and evaluation losses.
As we can see from the results, the training and evaluation losses converged faster as we increased the learning rate up until a value of 1e-4. Beyond this point, the convergence slowed down significantly. We visualized these experiments in the W&B charts below for easy reference:
[W&B charts (6 runs): training and evaluation loss for different learning rates]
Based on these results, we have decided to use a learning rate of 1e-4 for the final training run. This value seems to provide a good balance between convergence speed and overall performance, and it should help us achieve the best possible results with our model.
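One convenient way to run this kind of comparison is a W&B sweep. The sketch below is illustrative only: the learning-rate grid, project name, and the train() body are placeholders, not the exact configuration behind the runs above:

```python
# Illustrative W&B sweep over learning rates (values and project name are placeholders).
import wandb

sweep_config = {
    "method": "grid",
    "metric": {"name": "eval_loss", "goal": "minimize"},
    "parameters": {"learning_rate": {"values": [1e-5, 3e-5, 1e-4, 3e-4]}},
}

def train():
    run = wandb.init()
    lr = run.config.learning_rate
    # ... build the model and run a short training job on the sample dataset ...
    run.log({"eval_loss": 0.0})  # placeholder metric

sweep_id = wandb.sweep(sweep_config, project="protein-lm")  # hypothetical project name
wandb.agent(sweep_id, function=train)
```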
Training
After experimenting with hyperparameter values on a sample dataset, it's time for a full training run! This is a good time to introduce our dataset.
As background, scientific knowledge bases such as UniProt and UniParc contain known protein sequences. The UniRef databases were created by clustering these sequences to remove redundant entries. We will be using a version of UniRef hosted on the Hugging Face Hub, which has already been split into train, validation, and test subsets for us.
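Loading the data then looks roughly like this; the dataset identifier below is a hypothetical placeholder, so substitute the one referenced in the example repository:

```python
# Load a UniRef-style dataset from the Hugging Face Hub (identifier is a placeholder).
from datasets import load_dataset

raw = load_dataset("your-org/uniref-subset")  # hypothetical dataset name
print(raw)                                    # expect train / validation / test splits
print(raw["train"][0])                        # a single protein sequence record
```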
As we train our model, we can see that both the training and evaluation losses are decreasing, while the accuracy is increasing. It may be worthwhile to continue training the model for a longer period of time.
[W&B charts (run stilted-firefly-116): training and evaluation loss and accuracy]
Would it be a crazy idea to train your own protein language model? If you're interested in giving it a try, you can get started by cloning and following the instructions in our example repository.
Evaluation
As we are training a foundation model on a pretext task, we are focusing on minimizing both training and evaluation loss, as these metrics should be correlated with the model's performance on downstream tasks. Another metric used in language modeling is perplexity, which measures how well the model is able to predict the test set (a perfect model would have a perplexity of 1).
However, it is important to note that these metrics can be influenced by the tokenization scheme used. It is possible to compare models that use the same tokenizer, but caution should be taken when comparing models with very different vocabulary sizes (e.g. a model with a vocabulary of 22 amino acids versus one with a vocabulary of 1000 subsequences).
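Concretely, perplexity is just the exponential of the mean per-token cross-entropy measured on held-out data:

```python
# Perplexity = exp(mean per-token cross-entropy); a perfect model scores 1.0.
import jax.numpy as jnp

def perplexity(mean_cross_entropy):
    return jnp.exp(mean_cross_entropy)

print(perplexity(0.0))  # -> 1.0
```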
[W&B charts (run stilted-firefly-116): evaluation metrics]
A more reliable way to assess the quality of a language model is to evaluate its performance on downstream tasks, such as fine-tuning it to predict protein function, 2D or 3D structure, or biochemical properties like stability. This is an important topic and probably deserves a separate report.
Conclusion
In this article, we described the process of training a protein language model using Cloud TPU with the help of JAX, Flax, Optax, and Weights & Biases. We believe that there is a vast and largely untapped potential for research and application in this field. By leveraging powerful and efficient computing resources and frameworks, as well as an experiment management platform, we hope to facilitate new discoveries in the field of protein language modeling.
Explore the code used in this article: https://github.com/wandb/examples/tree/master/examples/jax/jax-llm
Related Reading on Protein Research: