
Multi-Token Prediction: A Promising New Idea For LLMs

Teaching LLMs to look farther into the future?
In recent years, the development of large language models has revolutionized natural language processing, enabling machines to generate human-like text and perform a wide range of language understanding tasks. However, these models still face challenges in generating coherent and contextually appropriate text, especially when it comes to long-range dependencies and decision-making at critical choice points. Enter multi-token prediction—a novel approach that promises to unlock the full potential of language models by considering longer contexts during training.

Next Token Prediction

The standard approach for training large language models is a next-token prediction loss: the model learns to predict the next token in a sequence given the previous tokens. The researchers behind the multi-token prediction work argue, however, that training models to predict several future tokens at once yields higher sample efficiency and better downstream performance.
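For concreteness, here is a minimal PyTorch sketch of that baseline objective, assuming a generic decoder-only `model` that maps token ids to per-position logits (the model and shapes are illustrative, not from the paper): the logits at each position are scored against the token one step ahead with cross-entropy.

```python
import torch
import torch.nn.functional as F

def next_token_loss(model, tokens):
    """Standard next-token prediction loss (illustrative sketch).

    tokens: LongTensor of shape (batch, seq_len); `model` is assumed to map
    token ids to logits of shape (batch, seq_len, vocab_size).
    """
    logits = model(tokens)                    # (B, T, V)
    pred = logits[:, :-1, :]                  # positions 0..T-2 predict ...
    target = tokens[:, 1:]                    # ... tokens 1..T-1
    return F.cross_entropy(
        pred.reshape(-1, pred.size(-1)),      # (B*(T-1), V)
        target.reshape(-1),                   # (B*(T-1),)
    )
```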

The Idea

The key idea is that at each position in the training corpus, instead of just predicting the next token, the model is tasked with predicting the following n tokens simultaneously using n independent output heads stacked on top of a shared model trunk. This multi-token prediction is treated as an auxiliary training task.
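A rough sketch of this setup (not the paper's exact implementation): put n lightweight output heads on a shared trunk and average the cross-entropy losses for offsets 1 through n. The `trunk`, `d_model`, `vocab_size`, and `n_future` names below are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTokenHeads(nn.Module):
    """n independent output heads on top of a shared trunk (illustrative sketch)."""

    def __init__(self, trunk, d_model, vocab_size, n_future=4):
        super().__init__()
        self.trunk = trunk   # shared transformer trunk: token ids -> (B, T, d_model)
        self.heads = nn.ModuleList(
            [nn.Linear(d_model, vocab_size) for _ in range(n_future)]
        )
        self.n_future = n_future

    def loss(self, tokens):
        h = self.trunk(tokens)                # (B, T, d_model)
        losses = []
        for k, head in enumerate(self.heads, start=1):
            # Head k predicts the token k steps ahead of each position.
            logits = head(h[:, :-k, :])       # (B, T-k, V)
            target = tokens[:, k:]            # (B, T-k)
            losses.append(F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                target.reshape(-1),
            ))
        # With n_future=1 this reduces to the usual next-token objective,
        # which is why the extra heads act as an auxiliary signal.
        return torch.stack(losses).mean()
```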
In their experiments on code and natural language models up to 13 billion parameters, the researchers found that multi-token prediction consistently outperformed the next-token prediction baselines on generative benchmarks, with models solving around 15% more code problems on average. The gains were increasingly pronounced for larger model sizes.

Added Efficiency

Moreover, the additional prediction heads trained via multi-token prediction can be leveraged at inference time to enable faster decoding using techniques like speculative decoding. Their best 4-token prediction model achieved up to 3x speedup without compromising quality.
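As a hedged illustration of how those extra heads can speed up decoding (one possible scheme, not necessarily the paper's exact one): all n heads draft a short block of future tokens from a single trunk pass, and a second pass verifies the drafts so the accepted tokens match what plain greedy next-token decoding would have produced. The `trunk`/`heads` interface is carried over from the training sketch above and assumes batch size 1.

```python
import torch

@torch.no_grad()
def greedy_self_speculative_decode(model, tokens, num_new_tokens=64):
    """Greedy decoding that uses the extra heads as drafters (illustrative sketch).

    Assumes `model.trunk(tokens) -> (1, T, d_model)` and `model.heads[k]`
    predicting the token k+1 positions ahead, as in the sketch above.
    """
    target_len = tokens.size(1) + num_new_tokens
    n = len(model.heads)
    while tokens.size(1) < target_len:
        # Draft pass: one trunk call at the current end proposes n future tokens.
        h_last = model.trunk(tokens)[:, -1, :]                   # (1, d_model)
        draft = torch.stack(
            [head(h_last).argmax(-1) for head in model.heads], dim=1
        )                                                        # (1, n)

        # Verification pass: extend the sequence with the first n-1 drafts and
        # read off, at each drafted position, what greedy decoding would emit.
        extended = torch.cat([tokens, draft[:, : n - 1]], dim=1)
        h_ext = model.trunk(extended)                            # (1, T+n-1, d_model)
        verify = model.heads[0](h_ext[:, tokens.size(1) - 1 :, :]).argmax(-1)  # (1, n)

        # verify[:, 0] equals draft[:, 0] by construction; keep later drafts
        # only while they agree with the verified greedy continuation.
        accepted = [draft[:, 0]]
        for k in range(1, n):
            if torch.equal(verify[:, k], draft[:, k]):
                accepted.append(draft[:, k])
            else:
                accepted.append(verify[:, k])   # the corrected greedy token
                break
        tokens = torch.cat([tokens] + [a.unsqueeze(1) for a in accepted], dim=1)
    return tokens[:, :target_len]
```

When most drafts are accepted, each pair of trunk passes yields several tokens instead of one, which is where this style of speculative decoding gets its speedup.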

Core Results

Training large language models to predict multiple future tokens at once (multi-token prediction) instead of just the next token results in better sample efficiency and improved downstream performance. On code benchmarks like MBPP and HumanEval, models trained with 4-token prediction solved around 15% more problems on average than next-token prediction baselines. The benefits grow with model size: the 13B-parameter multi-token model solved 12% more MBPP problems and 17% more HumanEval problems than a comparable next-token model. Multi-token prediction also enabled faster inference through techniques like speculative decoding. On natural language tasks such as summarization, multi-token models outperformed next-token baselines, though the gains were smaller than on code. Finally, controlled experiments showed that multi-token prediction helped smaller models acquire induction capabilities and generalize better on arithmetic tasks than simply scaling up the model.

Why Does this Work?

The researchers hypothesize that multi-token prediction encourages better transfer of information across positions and development of algorithmic reasoning capabilities. Through controlled experiments, they demonstrate that multi-token prediction indeed helps smaller models acquire induction heads and generalize better on arithmetic tasks compared to tripling the model size.
Overall, multi-token prediction is a simple yet effective modification that results in stronger and faster transformer language models at no additional training cost. The researchers hope this work will motivate further exploration of novel auxiliary losses beyond just next-token prediction.