
Meta presents Coconut: Augmenting LLM Reasoning with Latent Thoughts

Created on December 12|Last edited on December 12
Large language models have revolutionized artificial intelligence by demonstrating remarkable capabilities in natural language reasoning. One of the most impactful techniques is Chain-of-Thought (CoT) reasoning, where models break down complex tasks into a sequence of intermediate steps expressed in language. However, CoT reasoning relies heavily on language tokens, which can be verbose, computationally expensive, and inefficient. Enter Coconut (Chain of Continuous Thought), a new framework that fundamentally changes how machines reason. By replacing language steps with compact latent representations called latent thoughts, Coconut enables reasoning to take place in a high-dimensional continuous latent space, free from the constraints of natural language.

Latent Thoughts and the Multi-Stage Curriculum

The core innovation in Coconut lies in its use of latent thoughts (the paper calls them continuous thoughts). Instead of decoding the model's last hidden state into a word token, Coconut feeds that hidden state straight back into the model as the next input embedding. Reasoning steps are therefore encoded as continuous vectors rather than language tokens, and each latent thought becomes the input for the next, so the model never has to spell out every detail in natural language. For example, instead of processing a reasoning chain like “Add 2 and 4 to get 6” and “Multiply 6 by 3 to get 18” entirely in language, Coconut replaces these steps with compact latent tokens that carry the same logic.
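The difference between the two modes can be sketched in plain Python. This is a minimal toy under stated assumptions, not Meta's code: `toy_forward` is a hypothetical stand-in for the transformer that only returns a "last hidden state", and `EMBED`, `decode_token`, and `latent_reasoning` are illustrative names. In language mode the hidden state is decoded into a token whose embedding is fed back; in latent mode the hidden state itself is fed back.

```python
import math

# Hypothetical stand-in for the transformer: maps the whole input sequence
# (a list of vectors) to one "last hidden state". The paper uses GPT-2;
# only the interface is mimicked here.
def toy_forward(inputs):
    dim = len(inputs[0])
    mean = [sum(vec[i] for vec in inputs) / len(inputs) for i in range(dim)]
    return [math.tanh(1.5 * x + 0.1 * i) for i, x in enumerate(mean)]

# Hypothetical toy vocabulary: one embedding vector per token.
EMBED = {
    "<bot>": [0.0, 1.0, 0.0],
    "<eot>": [0.0, 0.0, 1.0],
    "6": [1.0, 0.0, 0.0],
    "18": [0.5, 0.5, 0.5],
}

def decode_token(hidden):
    # Language mode: score every token by dot product with the hidden state
    # and keep the best one (a stand-in for an LM head plus argmax).
    return max(EMBED, key=lambda tok: sum(h * e for h, e in zip(hidden, EMBED[tok])))

def latent_reasoning(question_embeds, num_thoughts):
    seq = list(question_embeds) + [EMBED["<bot>"]]
    thoughts = []
    for _ in range(num_thoughts):
        hidden = toy_forward(seq)
        # Latent mode: the hidden state itself becomes the next input
        # embedding; no token is decoded or sampled.
        seq.append(hidden)
        thoughts.append(hidden)
    seq.append(EMBED["<eot>"])
    # Language mode resumes after <eot>: decode an answer token.
    return thoughts, decode_token(toy_forward(seq))

thoughts, answer = latent_reasoning([[0.2, 0.1, 0.0], [0.3, 0.0, 0.4]], num_thoughts=2)
print(len(thoughts), answer)
```

With a real model, the hidden state would be passed back through the model's embedding-input path (Hugging Face transformers models, for instance, accept an `inputs_embeds` argument); that wiring is model-specific and left out of this toy.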
The transition from language-based reasoning to latent reasoning is achieved through a multi-stage curriculum training process. At the outset, the model operates fully in CoT reasoning mode, where each step is expressed in natural language and consists of multiple tokens. For example, the step “Add 2 and 4 to get 6” might be broken into tokens like ["Add", "2", "and", "4", "to", "get", "6"]. The model generates these steps sequentially, supervised by the CoT training data.

In later training stages, Coconut begins replacing language steps with latent thoughts. The replaced language step is removed from the input, and in its place the model generates a continuous thought: the final-layer hidden state at the last position, appended directly as the next input embedding. This hidden state serves as the latent token (or continuous thought) standing in for the step. The latent token has no token-level target of its own; instead, the training loss is computed on the language tokens that follow it, starting with the next reasoning step, so the latent token must encode whatever information is needed to predict that step correctly.
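One way to picture "supervised only through what comes next" is a masked loss that skips the question and latent positions. The sketch below is a hypothetical illustration, not the paper's training code: `masked_nll`, the layout labels, and the probabilities are made up, and excluding the `<bot>`/`<eot>` positions is an assumption of this sketch.

```python
import math

def masked_nll(target_probs, supervised):
    """Mean negative log-likelihood over the supervised positions only.

    target_probs[i]: probability the model gave the correct token at position i.
    supervised[i]:   True if position i contributes to the loss.
    """
    losses = [-math.log(p) for p, keep in zip(target_probs, supervised) if keep]
    return sum(losses) / len(losses)

# Hypothetical stage-1 layout: the question and the latent slot carry no target.
layout = ["Q", "Q", "Q", "<bot>", "latent", "<eot>", "step2", "step2", "answer"]
supervised = [tok in ("step2", "answer") for tok in layout]
# Made-up probabilities; the low values at question positions are ignored.
probs = [0.01, 0.02, 0.03, 0.9, 0.5, 0.9, 0.5, 0.25, 0.125]
print(masked_nll(probs, supervised))  # about 1.386, i.e. 2 * ln 2
```

Only the last three positions count, so the latent slot influences the loss solely through how well the model then predicts Step 2 and the answer.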
This replacement is introduced gradually. In Stage 1, only the first reasoning step is replaced by a latent thought, while the remaining steps stay in natural language. The input looks like [Question] <bot> [Latent Token 1] <eot>, followed by the remaining steps, and the model is trained to generate the language that follows, from [Step 2 in language] through the answer. By tying the latent token to the language steps that follow it, the model is incentivized to learn meaningful and predictive latent representations.

As training progresses, more reasoning steps are replaced with their latent counterparts. In Stage 2, the first two steps become latent tokens inside a single latent span, [Question] <bot> [Latent Token 1] [Latent Token 2] <eot>, and the target output begins at [Step 3 in language]. By the final stage, all reasoning steps are represented as latent tokens, and the model is trained to predict the final answer directly from them: [Question] <bot> [Latent Token 1] ... [Latent Token N] <eot> → [Answer in language].
The number of latent tokens used to replace each reasoning step is set by a hyperparameter c. With c=1, each replaced step becomes a single latent token; with c=2, each replaced step becomes two latent tokens, generated one after another by feeding the model’s previous hidden state back into itself. A larger c gives the model more capacity to encode a complex step, spreading its information across several latent tokens. The value of c stays fixed during training; what grows from stage to stage is the number of replaced steps, so stage k contains k × c latent tokens.
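The stage-by-stage layout can be written as a small helper. This is a sketch under assumptions: `stage_layout` is a hypothetical name, stage 0 is taken to be plain chain of thought, and all latent slots are wrapped in one `<bot> ... <eot>` span.

```python
def stage_layout(num_steps, stage, c):
    """Input layout for curriculum stage `stage` (0 = plain chain of thought).

    At stage k the first k language steps are dropped and replaced by
    k * c latent slots inside one <bot> ... <eot> span.
    """
    k = min(stage, num_steps)  # the final stage replaces every step
    layout = ["[Question]"]
    if k > 0:
        layout += ["<bot>"] + ["[latent]"] * (k * c) + ["<eot>"]
    layout += [f"[Step {i}]" for i in range(k + 1, num_steps + 1)]
    layout.append("[Answer]")
    return layout

for stage in range(4):
    print(stage, " ".join(stage_layout(num_steps=3, stage=stage, c=1)))
```

For a three-step problem with c=1, stage 0 prints the full language chain, stage 2 prints two latent slots followed by `[Step 3]`, and stage 3 leaves only latent slots before `[Answer]`. Setting c=2 doubles the latent slots at every stage without changing which steps are replaced.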

Efficiency and Breadth-First Search in Latent Space

This approach brings several theoretical advantages. By reasoning in latent space, Coconut avoids the verbosity of language tokens, potentially reducing computational overhead. Moreover, latent thoughts enable the model to explore multiple reasoning paths simultaneously, akin to a breadth-first search, instead of committing prematurely to a single deterministic chain. This flexibility could allow Coconut to perform well for tasks that require backtracking, planning, or multi-path exploration.
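The breadth-first intuition can be illustrated with a toy graph. This is a loose analogy, not Coconut's mechanism: the graph and the names `GRAPH`, `greedy_commit`, and `frontier_search` are hypothetical. A greedy reasoner commits to one successor per step and cannot recover from a dead end, while a weighted frontier, standing in for a vector that encodes several candidate next steps at once, keeps every branch alive.

```python
GRAPH = {
    "A": ["B", "C"],
    "B": ["D"],      # B leads to a dead end
    "C": ["E"],
    "D": [],
    "E": ["Goal"],
    "Goal": [],
}

def greedy_commit(start, goal, max_depth):
    # Commit to the first successor at every step; no backtracking.
    node = start
    for _ in range(max_depth):
        if node == goal:
            return True
        if not GRAPH[node]:
            return False  # stuck at a dead end
        node = GRAPH[node][0]
    return node == goal

def frontier_search(start, goal, max_depth):
    # Keep a weighted set of all nodes reachable so far, splitting each
    # node's weight evenly among its successors (breadth-first expansion).
    frontier = {start: 1.0}
    for _ in range(max_depth):
        if goal in frontier:
            return True, frontier
        nxt = {}
        for node, weight in frontier.items():
            for succ in GRAPH[node]:
                nxt[succ] = nxt.get(succ, 0.0) + weight / len(GRAPH[node])
        frontier = nxt
    return goal in frontier, frontier

print(greedy_commit("A", "Goal", 5))    # False: A -> B -> D is a dead end
print(frontier_search("A", "Goal", 5))  # (True, {'Goal': 0.5})
```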
Critically, Coconut maintains reasoning quality through indirect supervision. The latent tokens themselves are not directly supervised; instead, their content is shaped by the loss on downstream predictions, such as the next reasoning step or the final answer. This pushes the latent tokens to encode information that supports accurate predictions, keeping them aligned with the task.
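Indirect supervision works because gradients flow backward through the chain of latent thoughts. The sketch below uses a hypothetical scalar toy model (not the paper's architecture): two latent thoughts `h1` and `h2` share one weight `w`, only the final prediction is compared with a target, and the hand-derived gradient is checked against a finite difference in the test. It also shows why latent thoughts must be unrolled in order during training: `h2` cannot be computed until `h1` exists.

```python
import math

def forward(w, v, x):
    # Two continuous thoughts share the weight w; v reads out the answer.
    h1 = math.tanh(w * x)   # first latent thought
    h2 = math.tanh(w * h1)  # second latent thought, fed by the first
    return h1, h2, v * h2

def loss(w, v, x, y):
    # Only the final prediction is compared with a target.
    return (forward(w, v, x)[2] - y) ** 2

def grad_w(w, v, x, y):
    # Chain rule from the loss back through h2 and then h1.
    h1, h2, y_hat = forward(w, v, x)
    dh1_dw = (1 - h1 ** 2) * x
    dh2_dw = (1 - h2 ** 2) * (h1 + w * dh1_dw)
    return 2 * (y_hat - y) * v * dh2_dw

w, v, x, y = 0.7, 1.3, 0.9, 0.2
start = loss(w, v, x, y)
for _ in range(100):
    w -= 0.1 * grad_w(w, v, x, y)  # plain gradient descent on w
print(start, loss(w, v, x, y))
```

Neither `h1` nor `h2` has a target of its own, yet both change as `w` is updated, because the answer loss reaches them through the chain rule.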

Experiments and Results

The Coconut experiments used a pre-trained GPT-2 model, a transformer-based language model, as the base architecture.
To validate Coconut’s effectiveness, the framework was tested on three datasets: GSM8k for math reasoning, ProntoQA for logical reasoning, and ProsQA for proof-based reasoning. Coconut compared favorably with the baselines while generating far fewer tokens than standard CoT, though on GSM8k standard CoT remains more accurate. On GSM8k, Coconut achieved 34.1% accuracy, surpassing iCoT’s 30.0%. On ProntoQA, Coconut demonstrated near-perfect accuracy of 99.9% with just three latent tokens per reasoning step. On the more complex ProsQA dataset, designed to challenge planning and backtracking abilities, Coconut achieved 96.6% accuracy, far exceeding traditional CoT’s 76.7%.

Coconut’s move to latent reasoning is a notable departure from how language models usually approach complex tasks. By reasoning outside the bottleneck of language tokens, it lets models solve some tasks with fewer generated tokens and keep several reasoning paths open at once. Open questions remain about how to interpret latent thoughts and how well the method scales to larger models and to more diverse, complex datasets. Still, Coconut points toward a future in which AI reasoning is not constrained to words and can instead draw on the continuous latent space.



Tags: ML News