SUPRA: Uptraining Transformers to Recurrent Neural Networks for Efficient Inference
Improvements in linear language modeling!
Transformers have established themselves as the premier model architecture, particularly for their remarkable performance on a variety of tasks. However, transformers are memory-intensive at inference time, and their attention cost grows with the number of tokens in the context, which presents significant challenges. To address these issues, the paper "Linearizing Large Language Models" introduces an approach named Scalable UPtraining for Recurrent Attention (SUPRA). This method converts pre-trained transformers into recurrent neural networks (RNNs), retaining the benefits of pre-training while achieving efficient inference.
Understanding the Challenges of Transformers
Transformers face per-token inference costs and memory use that grow with sequence length, because each new token attends to a cache of keys and values for every previous token. In contrast, RNNs offer fixed-cost inference because they maintain a constant-size hidden state, making them attractive for tasks that require efficient and scalable inference.
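To make this concrete, here is a quick back-of-the-envelope sketch in Python comparing the memory a transformer must hold during generation (its key-value cache) with the fixed-size state of a linear-attention layer. The 7B-scale shapes (32 layers, 32 heads, head dimension 128, fp16) are assumptions chosen for illustration, not figures from the paper.

```python
# Assumed Llama2-7B-like shapes for illustration: 32 layers, 32 heads, head_dim 128, fp16.
BYTES = 2                      # fp16
LAYERS, HEADS, HEAD_DIM = 32, 32, 128

def kv_cache_bytes(seq_len):
    # K and V cached for every past token, in every layer and head.
    return 2 * LAYERS * HEADS * HEAD_DIM * seq_len * BYTES

def linear_state_bytes():
    # One d x d state matrix plus a d-dim normalization vector per head and layer,
    # independent of sequence length.
    return LAYERS * HEADS * (HEAD_DIM * HEAD_DIM + HEAD_DIM) * BYTES

for T in (4_096, 32_768, 131_072):
    print(f"T={T:>7}: KV cache ~ {kv_cache_bytes(T)/2**30:5.1f} GiB, "
          f"recurrent state ~ {linear_state_bytes()/2**20:5.1f} MiB (constant)")
```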
The concept of linear transformers was introduced to mitigate the computational overhead of the standard softmax attention mechanism. Linear transformers replace softmax with a linear similarity function, which allows attention to be reformulated as an RNN. Despite these efforts, linear transformers still underperform softmax transformers on many benchmarks, largely because of training instabilities and the expense of pre-training them from scratch.
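The reformulation itself is easy to verify numerically. The sketch below (not the paper's code) computes causal linear attention two ways, in the parallel transformer-style form and in the recurrent RNN-style form, and checks that the outputs match. The ELU+1 feature map is an illustrative choice from the linear-attention literature; SUPRA instead learns an MLP kernel, as described below.

```python
import torch

torch.manual_seed(0)
T, d = 6, 4                      # sequence length, head dimension
q = torch.randn(T, d)
k = torch.randn(T, d)
v = torch.randn(T, d)

phi = lambda x: torch.nn.functional.elu(x) + 1   # positive feature map (illustrative)

# Parallel form: causal kernelized attention, quadratic in T like softmax attention.
scores = phi(q) @ phi(k).T                        # (T, T) similarity matrix
scores = scores * torch.tril(torch.ones(T, T))    # causal mask
parallel_out = (scores @ v) / scores.sum(-1, keepdim=True)

# Recurrent form: constant-size state, linear in T.
s = torch.zeros(d, d)            # running sum of outer products of keys and values
z = torch.zeros(d)               # running normalization factor
recurrent_out = []
for t in range(T):
    s = s + torch.outer(phi(k[t]), v[t])
    z = z + phi(k[t])
    recurrent_out.append(phi(q[t]) @ s / (phi(q[t]) @ z))
recurrent_out = torch.stack(recurrent_out)

print(torch.allclose(parallel_out, recurrent_out, atol=1e-5))   # True
```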
SUPRA: The Uptraining Approach
What is Uptraining? Uptraining refers to adapting a pre-trained model to a new architecture with a comparatively small amount of additional training, as opposed to fine-tuning, which adapts a model to a new task or dataset while leaving its architecture unchanged. SUPRA specifically converts pre-trained transformers into RNNs by modifying their attention mechanisms.
Linearizing the Attention Mechanism: The core of SUPRA lies in transforming the attention mechanism of transformers into a recurrent form. The softmax attention is replaced with a linear function that allows for recurrent updates. This transformation is crucial because it enables the model to update its state incrementally, similar to an RNN.
s_t = s_{t-1} + k_t v_tᵀ,   z_t = z_{t-1} + k_t,   y_t = (q_t · s_t) / (q_t · z_t)
State Updates: The recurrent state (s) and a normalization factor (z) are updated at each time step using the transformed keys (k) and values (v). These updates are performed in a manner that allows the model to process sequences one token at a time, significantly reducing inference costs.
Detailed Steps of SUPRA
q_t = RoPE(φ(W_q x_t)),   k_t = RoPE(φ(W_k x_t)),   y_t = GroupNorm(q_t · s_t)
where φ is a small trainable MLP kernel applied to the queries and keys, and GroupNorm replaces the division by q_t · z_t used in standard linear attention.
Initialization: Initialize the recurrent state and normalization factor to zero.
Transformation of Queries and Keys: For each token in the input sequence, compute the query and key with the learned weight matrices, then apply the MLP kernel and rotary positional embeddings (RoPE) so the resulting vectors carry positional information.
Recurrent State Updates: Update the state with the outer product of the transformed key and the value, and update the normalization factor with the transformed key.
Computing Attention Outputs: Compute the attention output at each time step from the product of the transformed query and the state, and apply GroupNorm to the result for stability, in place of the softmax (or sum-based) normalization.
Repeat the above steps for each token in the sequence, so the model processes the input recurrently; a sketch of one such recurrent step follows below.
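To tie the steps together, here is a minimal single-head sketch of one recurrent attention step in PyTorch. The class and helper names, the one-layer ReLU kernel, the single-group GroupNorm, and the simplified RoPE are illustrative assumptions rather than the paper's exact configuration.

```python
import torch
import torch.nn as nn

class RecurrentLinearAttentionHead(nn.Module):
    def __init__(self, d_model: int, d_head: int):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_head, bias=False)
        self.w_k = nn.Linear(d_model, d_head, bias=False)
        self.w_v = nn.Linear(d_model, d_head, bias=False)
        # Learned MLP kernel applied to queries and keys (stands in for softmax's exp).
        self.kernel = nn.Sequential(nn.Linear(d_head, d_head), nn.ReLU())
        # GroupNorm on the output replaces the usual 1 / (q . z) normalization.
        self.norm = nn.GroupNorm(num_groups=1, num_channels=d_head)

    @staticmethod
    def rope(x: torch.Tensor, pos: int) -> torch.Tensor:
        # Rotary positional embedding: rotate consecutive pairs of dims by pos * theta.
        d = x.shape[-1]
        theta = 10000 ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)
        angles = pos * theta
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * torch.cos(angles) - x2 * torch.sin(angles)
        out[..., 1::2] = x1 * torch.sin(angles) + x2 * torch.cos(angles)
        return out

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (seq_len, d_model); processed one token at a time.
        d_head = self.w_q.out_features
        s = torch.zeros(d_head, d_head)     # recurrent state, initialized to zero
        z = torch.zeros(d_head)             # normalization factor, initialized to zero
        outputs = []
        for t, x in enumerate(tokens):
            q = self.rope(self.kernel(self.w_q(x)), t)   # transformed query
            k = self.rope(self.kernel(self.w_k(x)), t)   # transformed key
            v = self.w_v(x)
            s = s + torch.outer(k, v)        # state update with key/value outer product
            z = z + k                        # normalization-factor update
            y = q @ s                        # attention readout for this token
            y = self.norm(y.unsqueeze(0)).squeeze(0)   # GroupNorm instead of dividing by q . z
            outputs.append(y)
        return torch.stack(outputs)

# Usage: head = RecurrentLinearAttentionHead(d_model=64, d_head=32)
#        y = head(torch.randn(10, 64))      # -> (10, 32)
```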
Fine-Tuning the Model
After converting the transformer into an RNN using the above steps, the model is uptrained on a comparatively small dataset. This fine-tuning adjusts the model's weights, in particular the newly introduced components (the MLP kernel and GroupNorm), to recover performance. Importantly, this step requires only a fraction of the compute needed to pre-train from scratch.
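As a rough sketch of what this uptraining stage could look like in practice, the snippet below sets up a standard next-token-prediction training step over the converted model. The model object, the module-name filter ("kernel", "norm"), and the hyperparameters are hypothetical placeholders; restricting the optimizer to the new components is shown only as an option.

```python
import torch

def make_optimizer(model: torch.nn.Module, lr: float = 1e-5, new_only: bool = False):
    if new_only:
        # Optionally restrict training to the newly introduced modules
        # (MLP kernel and GroupNorm) to keep the uptraining step cheap.
        params = [p for name, p in model.named_parameters()
                  if "kernel" in name or "norm" in name]
    else:
        params = model.parameters()
    return torch.optim.AdamW(params, lr=lr)

def uptrain_step(model, batch, optimizer):
    # batch["input_ids"]: (B, T) token ids; standard next-token prediction loss.
    logits = model(batch["input_ids"])                    # (B, T, vocab)
    loss = torch.nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        batch["input_ids"][:, 1:].reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```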
Results of SUPRA
The researchers tested the SUPRA method by uptraining a range of models from 1B to 7B parameters, including Llama2 and Mistral models. These models were tested on standard language understanding benchmarks and long-context evaluations to assess their performance and limitations.

Language Modeling Results Against Other Linear Models
The models were evaluated with the Eleuther evaluation harness on standard natural language understanding (NLU) tasks. Notably, SUPRA models initialized from strong pre-trained transformers (Llama2-7B and Mistral-7B) maintained high performance across most benchmarks, outperforming other linear models such as RWKV-5 with minimal uptraining.
SUPRA's performance was benchmarked against other pre-trained recurrent models and models trained from scratch. For instance, a Mamba model trained from scratch on 1.2 trillion tokens served as a strong baseline. SUPRA models, despite using significantly less training data, achieved competitive results, highlighting the efficiency of the uptraining process.
Long-Context Evaluations
Long-context tasks are challenging for many models due to the necessity of retaining information over extended sequences. SUPRA models were evaluated on tasks from the SCROLLS benchmark, such as Qasper and NarrativeQA, at various context lengths. The performance of these models was compared with their training context lengths and beyond.
SUPRA models demonstrated a capacity to maintain performance beyond their training context lengths, a trait typically associated with recurrent models. However, transformers without modifications generally outperformed SUPRA models at their maximum training context length. This indicates that while SUPRA models are effective, there is room for improvement in handling very long contexts.
Key Insights and Comparisons
One significant finding was the importance of normalization in maintaining model performance during uptraining. SUPRA's use of GroupNorm instead of softmax normalization provided better stability and performance, particularly in larger models.
The experiments also showed that pre-training a linear model from scratch on a vast dataset can still edge out an uptrained model, but only at orders of magnitude more compute; this underscores the advantage of starting from strong pre-trained transformers and using SUPRA for efficient conversion.
Despite the promising results, SUPRA models exhibited persistent performance gaps in in-context learning tasks like MMLU. This suggests that linear models, even when uptrained from robust transformers, may have inherent limitations that need addressing through more sophisticated mechanisms or training strategies.
Conclusion
The SUPRA method provides a cost-effective and efficient approach to leveraging pre-trained transformers for tasks requiring recurrent neural networks. By converting transformers to RNNs, SUPRA maintains high performance on standard benchmarks and shows potential for long-context tasks. However, further research is needed to fully overcome the limitations in in-context learning and long-context performance.