
Google's New LLM Training Stack

Google has a new set of tools for those looking to train massive LLMs!
Recent advances in generative AI have led to the development of large language models at an unprecedented scale, consuming immense computational resources. Training these models requires a staggering amount of computing power, measured in exa-FLOPs, and involves coordinating tens of thousands of AI accelerator chips.

Leveraging TPU Multislice Training

To tackle the challenges of managing such a vast array of chips and optimizing the training process, Google Cloud has introduced Cloud TPU Multislice Training. This solution streamlines the orchestration, compilation, and overall optimization of machine learning training, and is designed to be scalable, reliable, and easy to use, with support for TPU v4 and v5e. The Multislice approach allows ML models to be trained efficiently at very large scale on Google's Cloud TPUs.
Key features of Cloud TPU Multislice Training include:

- Robust orchestration for scaling model training across thousands of TPU chips
- Performant compilation through the XLA compiler
- A flexible stack that supports popular ML frameworks and a variety of model architectures

This full-stack offering makes large-scale distributed training more efficient. The training run also relied on several components of the broader JAX training stack, which was central to setting up the distributed training job:

- Accelerated Processing Kit (XPK) for ML cluster and job orchestration
- MaxText, a scalable LLM implementation written in JAX
- Accurate Quantized Training (AQT) for efficient training in reduced numerical precision
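
To make this concrete, here is a minimal sketch of how a JAX-based training script on such a cluster typically begins. This is not MaxText's actual code; it assumes the script is launched on Cloud TPU VMs by an orchestrator such as XPK or GKE, where jax.distributed.initialize() can discover the cluster configuration from the environment.

```python
import jax

# Minimal sketch (not MaxText itself): bring up the JAX runtime across all
# hosts of a multi-host TPU job before any model code runs. On Cloud TPU VMs,
# jax.distributed.initialize() discovers the coordinator address, process
# count, and process id from the environment set up by the orchestration
# layer; outside that environment it needs explicit arguments.
jax.distributed.initialize()

# Every host runs the same script and can see both its local accelerators
# and the global device count.
print(
    f"process {jax.process_index()} of {jax.process_count()}: "
    f"{jax.local_device_count()} local devices, "
    f"{jax.device_count()} devices globally"
)
```

From here, a framework like MaxText can build a global device mesh spanning all slices, so the training step is written once for the whole cluster.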

A World Record?

Using Cloud TPU Multislice Training, Google Cloud ran what is believed to be the world's largest publicly disclosed LLM distributed training job, using over 50,000 TPU v5e chips. This surpassed the computing power of the top supercomputers, demonstrating the capability of Cloud TPU Multislice Training for large-scale LLM training. The run trained models of various sizes and used data parallelism across the TPU pods. Managing an operation of this scale required key optimizations in job orchestration, compiler performance, and efficient interaction with storage. The scalability results were impressive, with high utilization rates and strong scaling across multiple TPU pods.

Challenges and Solutions

The Google Cloud team faced significant challenges in managing storage for this large-scale distributed training of language models. These centered on interacting efficiently with persistent storage, particularly for data loading and checkpointing at the scale of the 199-pod cluster.

Data Loading Challenges
Initially, as the size of the compute cluster increased, especially beyond 64 pods, the team noticed a decline in performance due to the strain on Google Cloud Storage (GCS). This was primarily because loading data from GCS to such a large number of pods simultaneously put immense pressure on the storage system.

Solution:
To resolve this, the team implemented a distributed data loading strategy. This approach involved having only a subset of hosts responsible for loading data, rather than all pods trying to access GCS simultaneously. This strategy significantly alleviated the pressure on GCS, ensuring more stable and efficient data loading across the cluster.
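
The post doesn't show the mechanism, but the core idea is easy to sketch in JAX: designate a small subset of processes as data loaders and partition the input files among them, so only those hosts ever read from GCS. The constant and helper names below are illustrative, not part of any published API.

```python
import jax

# Illustrative sketch of loader-host selection; names and numbers are
# hypothetical, not Google's implementation.
NUM_LOADER_PROCESSES = 16  # small subset of all hosts that reads from GCS


def is_loader() -> bool:
    """Only the first NUM_LOADER_PROCESSES processes read input data."""
    return jax.process_index() < NUM_LOADER_PROCESSES


def files_for_this_process(all_files: list[str]) -> list[str]:
    """Partition the GCS input files among the loader processes only."""
    if not is_loader():
        return []  # non-loader hosts receive their batches over the network
    return all_files[jax.process_index()::NUM_LOADER_PROCESSES]
```

The batches read by the loader hosts then need to be redistributed to the devices owned by the other hosts, which JAX supports by resharding a global array across the mesh; that step is omitted here for brevity.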

Checkpointing Challenges
Another major issue was related to checkpointing. In typical scenarios, each data parallel replica would load the full checkpoint from GCS. For a large model (e.g., 128B parameters), this approach meant loading a checkpoint of about 1.536 TB for each of the 199 pods, amounting to a massive total bandwidth requirement. This process was not only bandwidth-intensive but also time-consuming, taking about 40 minutes for a complete load, which was far from ideal for efficient training operations.
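
The arithmetic behind that figure is worth spelling out. The stated 1.536 TB checkpoint for a 128B-parameter model works out to 12 bytes per parameter (presumably parameters plus optimizer state), and the naive scheme multiplies that by every replica:

```python
# Back-of-the-envelope check of the naive checkpoint-loading cost above.
params = 128e9               # 128B-parameter model
ckpt_bytes = 1.536e12        # ~1.536 TB checkpoint
num_pods = 199               # data-parallel replicas, each loading the full checkpoint

print(ckpt_bytes / params)           # 12.0 bytes per parameter
print(ckpt_bytes * num_pods / 1e12)  # ~305.7 TB read from GCS per restore
```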

Solution:
To tackle this, Google Cloud's team introduced a feature that lets a single pod load the entire checkpoint and then broadcast it to the other replicas across the cluster, leveraging the flexibility of the JAX framework. With this approach, loading the checkpoint into a single pod took approximately 12 seconds, and broadcasting the optimizer state to the other pods added about 4 seconds. That brought the total time down to just 16 seconds, a roughly 150x speedup over the 40-minute baseline for checkpoint loading.
Furthermore, for writing checkpoints, a similar optimization was applied. A single leader replica was designated to write the entire checkpoint, thus preventing high query rates to GCS and improving the efficiency of the checkpointing process.
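
The blog post doesn't include code for this, but the shape of the idea can be sketched with JAX's multi-host utilities: one process restores the checkpoint from GCS and broadcasts it to every other process, and only a designated leader writes checkpoints back. In the real system the broadcast happens between data-parallel replicas (pods) over the accelerator interconnect; the load_from_gcs and write_to_gcs callables below are hypothetical stand-ins, not Google's implementation.

```python
import jax
from jax.experimental import multihost_utils


def restore_with_broadcast(template, load_from_gcs, step):
    """Sketch: one process reads the checkpoint, then broadcasts it to all.

    `template` is a pytree of host-local arrays with the checkpoint's shapes
    and dtypes (e.g. from model init); `load_from_gcs` is a hypothetical
    reader that only process 0 actually calls.
    """
    if jax.process_index() == 0:
        ckpt = load_from_gcs(step)   # single reader hits GCS
    else:
        ckpt = template              # placeholder; ignored by the collective
    # All processes participate; values from non-source processes are discarded.
    return multihost_utils.broadcast_one_to_all(ckpt)


def save_from_leader(ckpt, write_to_gcs, step):
    """Sketch: only a designated leader process writes the checkpoint."""
    if jax.process_index() == 0:
        write_to_gcs(ckpt, step)     # hypothetical writer
    # Barrier so no host races ahead while the leader is still writing.
    multihost_utils.sync_global_devices(f"ckpt_saved_{step}")
```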

Impact of These Solutions

These storage optimizations had a profound impact on the efficiency and scalability of the training process. By reducing the data loading and checkpointing times, the Google Cloud team was able to significantly increase the throughput and performance of their large-scale ML training operations. These solutions exemplify the kind of innovative thinking necessary to manage the complexities of training AI models at an unprecedented scale.

Orchestration and Compilation at Scale

To manage the orchestration of over 50,000 accelerator chips for AI training, Google Cloud implemented a sophisticated solution using Google Kubernetes Engine (GKE). This involved using GKE's JobSet and Kueue features to enable the submission of both small and large-scale training jobs. The team optimized the handling of a vast number of virtual machines by efficiently managing internal IP addresses and pre-caching Docker images.

On the performance side, the system relies on XLA (Accelerated Linear Algebra), a compiler that optimizes workloads for machine learning accelerators. XLA employs SPMD (single program, multiple data) and GSPMD (Generalized SPMD) techniques for efficient parallel computing, which were crucial for partitioning the extensive computations across the cluster. Together, these solutions enabled efficient orchestration and high-throughput scheduling for training at an unprecedented scale.
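
To give a flavor of what SPMD/GSPMD looks like from the user's side in JAX, here is a small, generic sketch: the program is written once, shardings are attached to the inputs, and XLA partitions the computation and inserts the necessary collectives. The mesh shape, axis names, and array sizes are illustrative, not the ones used in the record-setting run.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 2D device mesh; with N devices this gives an (N, 1) grid, so the
# "model" axis is trivial here and exists only to illustrate the annotation.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

batch_sharding = NamedSharding(mesh, P("data", None))   # shard batch dim over "data"
weight_sharding = NamedSharding(mesh, P(None, "model"))  # shard output dim over "model"


@jax.jit
def forward(x, w):
    # GSPMD propagates the input shardings through this single-program matmul
    # and inserts whatever communication the partitioning requires.
    return jnp.dot(x, w)


x = jax.device_put(jnp.ones((8, 128)), batch_sharding)
w = jax.device_put(jnp.ones((128, 256)), weight_sharding)
y = forward(x, w)
```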

More Improvements to Come

Looking ahead, Google Cloud aims to further enhance start-up times, scaling efficiency, and overall performance of Cloud TPU Multislice Training. This effort is part of a larger commitment to advancing the capabilities of generative AI and LLMs. The project, a result of efforts within Google Cloud, demonstrates a significant step forward in distributed ML training.