
Scaling Llama 2 to 32k Tokens With LongLora

The need for LLMs that can digest long content is becoming increasingly important. Go beyond 4096 tokens with LongLora!
Created on October 24|Last edited on November 6

Introduction

As open-source large language models continue to improve, the need for models that can digest very long contexts is becoming increasingly important. One of the core limitations of existing LLMs is their fixed context window, which caps the maximum input length the model can accept.
In this article, we'll look at how to scale a Llama 2 model to 32k tokens with LongLora.


The problem LLMs have with long sequences

The transformer architecture itself can, in principle, operate on long sequences. The catch is that existing pre-trained transformers must be retrained on a large amount of data at a similar context length before they adapt properly to longer sequences. This causes several problems:

Problem 1: Quadratic attention

There have been multiple roadblocks on the path to long-context language models. First, the attention mechanism itself scales quadratically in both time and memory as sequence length grows. This is a major limitation for training long-context LLMs with the vanilla attention mechanism.

Problem 2: Sequence length extension with pre-trained models

The next major hurdle is reusing pre-trained models that were originally trained on shorter sequences. Say you have the weights of an LLM that was originally trained with a maximum sequence length of 4096. Simply using these weights as a starting point for fine-tuning on longer sequences (say 8192) has historically not worked well in practice.

Problem 3: Using LoRA on long sequences

Third, say you have solved the first two issues and would like to use LoRA to fine-tune your long-context LLM. It turns out that applying LoRA in the standard way works noticeably worse for long sequences than for shorter ones.
Luckily, recent work addresses all three problems. I'm going to walk through the solutions that make it possible to train long-context LLMs with minimal compute!

Solving these problems

Of course, this wouldn't be much of an article if we just listed problems with no solutions.

Solving Problem 1: Shift Short Attention

Quadratic scaling in transformers

The traditional self-attention mechanism in transformers operates on all pairs of tokens in the input sequence. Specifically, each token is compared to every other token, which results in a quadratic complexity O(n^2), where n is the length of the sequence.

Impact on GPU memory requirements

This quadratic scaling has serious implications for GPU memory:
Memory Usage for Intermediate Calculations: During forward and backward passes, intermediate representations of the attention scores and other tensors have to be stored. These become increasingly large as the sequence length increases, demanding more memory.
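To put the quadratic growth into concrete numbers, here is a back-of-the-envelope sketch of the memory needed just to materialize one layer's attention-score matrices for a single sequence. It assumes 32 attention heads (as in Llama 2 7B) and fp16 scores (2 bytes each), and it takes an optional `group_size` for attention computed only within fixed-size groups. It ignores everything else that lives in memory, and fused kernels such as FlashAttention avoid storing the full matrix altogether, so treat it as an illustration of the scaling rather than a real memory budget.

```python
def attention_score_bytes(seq_len, num_heads, bytes_per_elem=2, group_size=None):
    """Bytes to hold one layer's attention-score matrices for one sequence.

    Full attention builds a (seq_len x seq_len) matrix per head. With
    group_size set, attention is computed only inside groups, giving
    seq_len // group_size matrices of shape (group_size x group_size) per head.
    """
    if group_size is None:
        group_size = seq_len
    assert seq_len % group_size == 0, "seq_len must be divisible by group_size"
    num_groups = seq_len // group_size
    return num_groups * group_size * group_size * num_heads * bytes_per_elem


GiB = 2 ** 30
for n in (4096, 8192, 32768):
    full = attention_score_bytes(n, num_heads=32)
    grouped = attention_score_bytes(n, num_heads=32, group_size=n // 4)
    print(f"{n:>6} tokens: full {full / GiB:6.2f} GiB | 4 groups {grouped / GiB:6.2f} GiB")
```

At 32,768 tokens the full score matrices come to 64 GiB per layer, while splitting the sequence into four groups of 8,192 tokens brings that down to 16 GiB: with g equal-sized groups, the cost drops by a factor of g.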
To get around these challenges, the LongLora authors introduced Shift Short Attention (S2-Attn). Instead of letting every token attend to every other token, S2-Attn splits the sequence into groups and computes attention only within each group, so the cost grows with the group size rather than with the full sequence length. In half of the attention heads, the tokens are first shifted by half a group before grouping, so the groups in those heads straddle the boundaries of the groups in the other heads.

A diagram of Shift Short Attention from the LongLora paper

The interesting feature of Shift Short Attention is the shift. Without it, each group would be locked into its own isolated chunk of the sequence, and no information could flow between groups. By shifting the window of attention (in the paper, for half of the attention heads, by half a group), Shift Short Attention lets information flow between adjacent groups. Importantly, this is achieved without a significant increase in computational cost.
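Here is a minimal NumPy sketch of the grouping and shifting. It loosely follows the pseudocode in the LongLora paper rather than the repo's PyTorch implementation, and it assumes a single sequence laid out as `(seq_len, num_heads, head_dim)`, a causal mask inside each group, and a sequence length divisible by the group size. `group_attention` and `shift_short_attention` are names chosen for this example.

```python
import numpy as np


def group_attention(q, k, v, group_size):
    """Causal softmax attention computed independently inside each group.

    q, k, v have shape (seq_len, num_heads, head_dim).
    """
    n, h, d = q.shape
    g = group_size
    # Split the sequence axis into groups: (num_groups, g, heads, dim)
    qg, kg, vg = (x.reshape(n // g, g, h, d) for x in (q, k, v))
    scores = np.einsum("bqhd,bkhd->bhqk", qg, kg) / np.sqrt(d)
    # Causal mask within each group: position i may only attend to j <= i
    future = np.triu(np.ones((g, g), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = np.einsum("bhqk,bkhd->bqhd", weights, vg)
    return out.reshape(n, h, d)


def shift_short_attention(q, k, v, group_size):
    """S2-Attn sketch: half of the heads attend within groups shifted by half a group."""
    n, h, _ = q.shape
    assert n % group_size == 0 and h % 2 == 0
    half, shift = h // 2, group_size // 2

    def roll_second_half(x, amount):
        x = x.copy()
        x[:, half:] = np.roll(x[:, half:], amount, axis=0)  # shift along the sequence
        return x

    shifted = [roll_second_half(x, -shift) for x in (q, k, v)]
    out = group_attention(*shifted, group_size)
    return roll_second_half(out, shift)  # move shifted heads back to their positions


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 4, 8)) for _ in range(3))
out = shift_short_attention(q, k, v, group_size=4)
print(out.shape)  # (16, 4, 8)
```

Because the first half of the heads is untouched, its output matches plain grouped attention exactly; only the shifted half sees groups that straddle the original boundaries. Like the paper's pseudocode, `np.roll` wraps tokens around the ends of the sequence, so the last shifted group mixes the final tokens with the first ones.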
Interestingly, similar approaches have been applied successfully for computer vision tasks. A method known as TSM (temporal shift module) has been shown to match the performance of 3D-CNNs for video understanding tasks, utilizing only 2D convolutions along with the temporal shift module.
The TSM Module for video understanding
In the LongLora authors' experiments, Shift Short Attention performed comparably to standard self-attention while requiring a fraction of the compute. I was curious how Shift Short Attention compares to standard attention in memory usage, so I wrote a script that measures GPU VRAM usage at different sequence lengths. Below are the results, up to 10 GB of VRAM.


Run: Vanilla_att_vs_short_shift_att

As can be seen, Shift Short Attention requires significantly less memory than vanilla attention. This is crucial for training long-context LLMs.

Solving Problem 2: Positional Embedding Scaling

In the Transformer architecture, positional embeddings give the model information about the order of tokens in a sequence. Unlike recurrent or convolutional layers, attention does not inherently process data in order, so on its own it has no way to account for where each token sits in the sequence. The original Transformer solved this by adding position-specific vectors to the input embeddings; Llama 2 instead uses rotary positional embeddings (RoPE), which rotate each query and key vector by an angle proportional to the token's position. Either way, the model can take into account both the meaning of each token and its position when computing attention, which lets it handle a wide variety of sequence-based tasks.
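To make the rotary idea concrete, here is a small NumPy sketch of RoPE using the interleaved-pair formulation from the original RoFormer paper. Hugging Face's Llama code pairs dimension i with i + d/2 instead of adjacent dimensions, but the idea is the same. `rope_angles` and `apply_rope` are illustrative names, and `base=10000` is the commonly used default.

```python
import numpy as np


def rope_angles(positions, head_dim, base=10000.0):
    """Angle m * theta_i for every position m and frequency theta_i."""
    theta = base ** (-np.arange(0, head_dim, 2) / head_dim)  # head_dim // 2 frequencies
    return np.outer(np.asarray(positions, dtype=np.float64), theta)


def apply_rope(x, positions, base=10000.0):
    """Rotate each pair (x[..., 2i], x[..., 2i+1]) by the angle m * theta_i.

    x has shape (len(positions), head_dim); row j sits at positions[j].
    """
    angles = rope_angles(positions, x.shape[-1], base)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)


def score(m, n):
    """Attention logit between a query at position m and a key at position n."""
    return apply_rope(q[None, :], [m])[0] @ apply_rope(k[None, :], [n])[0]


print(score(10, 3), score(107, 100))  # same offset (7), same score up to rounding
```

The check shows the property that matters for the rest of this section: after rotation, a query-key dot product depends only on the offset between the two positions, not on their absolute values.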
Attempting to train a Transformer model on sequences longer than it was originally trained for can lead to complications. One key issue is that the model's positional embeddings are only designed to handle sequences up to a certain length, which means that longer sequences might not be accurately represented. This could degrade the model's performance or lead to unexpected behavior.
Methods have previously been developed to reduce the training resources required to extend the context window of pre-trained language models; however, these methods have limited applicability to some of today's most popular LLMs, like Llama 2.
Until recently, fine-tuning language models to handle longer sequences than they were originally trained on was somewhat mysterious and poorly understood. However, thanks to the work of Chen et al., who developed Position Interpolation, as well as Kaiokendev, extending the context window of pre-trained language models that use RoPE positional embeddings has become straightforward and computationally efficient.

Positional Embeddings and Context Window Extension

The work of Kaiokendev and Chen et al. shows that the limitation may lie not in the Transformer architecture itself but in the model's learned behavior around positional embeddings. They found that dividing the position indices fed into RoPE by a constant factor, so that a longer sequence maps onto the position range seen during pre-training, lets the model operate at significantly longer sequence lengths after only a short fine-tuning run, without a large amount of new data. This approach, called interpolation, contrasts with the conventional approach of extrapolation (feeding the model positions beyond its trained range), which tends to work poorly. It has opened the door to context windows well beyond the original token limit with minimal computational overhead.
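The difference between extrapolation and interpolation comes down to which position values reach the rotary embedding. The NumPy sketch below assumes Llama 2's 4,096-token pre-training context, a 32,768-token target, and a head dimension of 128, and computes RoPE angles with and without dividing positions by the scale factor. The first frequency is exactly 1 radian per position, so its column is the (possibly scaled) position itself.

```python
import numpy as np


def rope_angles(positions, head_dim, base=10000.0, scale=1.0):
    """RoPE angles m * theta_i; position interpolation divides m by `scale`."""
    theta = base ** (-np.arange(0, head_dim, 2) / head_dim)
    return np.outer(np.asarray(positions, dtype=np.float64) / scale, theta)


orig_ctx, new_ctx, head_dim = 4096, 32768, 128
scale = new_ctx / orig_ctx  # 8.0

positions = np.arange(new_ctx)
extrapolated = rope_angles(positions, head_dim)               # raw positions
interpolated = rope_angles(positions, head_dim, scale=scale)  # positions / 8

# theta_0 = 1, so column 0 holds the effective position index itself
print("largest position seen in pre-training:", orig_ctx - 1)
print("extrapolation:", extrapolated[:, 0].max())
print("interpolation:", interpolated[:, 0].max())
```

Without scaling, the model is asked to handle positions up to 32,767, eight times anything it saw in pre-training. With interpolation the largest position is 32,767 / 8 = 4,095.875, inside the trained range [0, 4096), and neighboring tokens simply sit 1/8 of a position apart.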

The Importance of Scale

One of the revelations from this work is that the scale factor plays a crucial role. The change was as simple as adding two lines of code to the existing rotary embedding code. The scale should be treated as a hyperparameter chosen for the target context length, and the same scale used during training must also be applied during inference.
A visualization of position interpolation for rotary embeddings. This minor change allows rapid fine-tuning from the original pre-trained model.
As the visualization shows, the position indices are scaled into a familiar range that the model saw when it was originally trained.

Solving Problem 3: LongLora

The researchers also noticed that LoRA struggled when used for context extension at longer sequence lengths. They discovered that unfreezing the embedding and normalization layers, which make up only a small portion of the model's parameters, dramatically improves LoRA fine-tuning while only marginally increasing the number of trainable parameters. Despite its simplicity, this adjustment proved very effective!
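A quick parameter count shows why unfreezing these layers is cheap. The sketch below uses the published Llama 2 7B shapes (hidden size 4096, MLP size 11008, 32 layers, a 32,000-token vocabulary, no bias terms, and an output head that is not tied to the input embedding) together with the rank r=8 LoRA on the q/k/v/o projections used in the training code later in this article.

```python
# Llama 2 7B shapes (from its published config)
hidden, intermediate, layers, vocab = 4096, 11008, 32, 32000
lora_rank = 8  # r=8, as in the LoraConfig in the training code

embed = vocab * hidden                     # model.embed_tokens
lm_head = vocab * hidden                   # separate output projection
attn_per_layer = 4 * hidden * hidden       # q_proj, k_proj, v_proj, o_proj
mlp_per_layer = 3 * hidden * intermediate  # gate_proj, up_proj, down_proj
norm = (2 * layers + 1) * hidden           # two RMSNorms per layer + final norm
total = embed + lm_head + layers * (attn_per_layer + mlp_per_layer) + norm

# LoRA on each (hidden x hidden) projection adds A (r x hidden) and B (hidden x r)
lora = layers * 4 * lora_rank * (hidden + hidden)

for name, count in [("base model", total), ("LoRA (r=8, q/k/v/o)", lora),
                    ("embed_tokens", embed), ("all norms", norm)]:
    print(f"{name:>20}: {count:>13,d} ({100 * count / total:.3f}% of base model)")
```

The base model comes to 6,738,415,616 parameters. The input embedding is about 1.9% of that and all normalization weights together about 0.004%, while the rank-8 LoRA adapters add about 0.12%.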

The Code

Choosing The Test Data For LongLora

In order to test out these methods, the researchers that introduced LongLora used the RedPajama dataset. The dataset is available for use on HuggingFace, but it contains over 1 trillion tokens, so I recommend using the togethercomputer/RedPajama-Data-1T-Sample version unless you have a plethora of compute.
To try out the methods of LongLora, we will fine-tune a model on a portion of the RedPajama dataset. The result that stood out most to me in the paper was LongLora's performance at a sequence length of 32K tokens.
To run this experiment, we will build on the official LongLora repo, which implements LongLora for the Llama 2 family of models. For these tests, we will use the 7B-parameter Llama 2 model. To get access to it, you need to request access from Meta using the same email address as your Hugging Face account, which then lets you download the weights from Hugging Face.
Here is the main training function used for experiments:
import math
import os
from functools import partial

import torch
import transformers
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier
from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, Trainer

# From the official LongLora repo:
from llama_attn_replace import replace_llama_attn
from gptneox_attn_replace import replace_gpt_neox_attn
# ModelArguments, TrainingArguments (a transformers.TrainingArguments subclass),
# smart_tokenizer_and_embedding_resize, tokenize_fn and the DEFAULT_*_TOKEN
# constants are defined in the repo's fine-tune.py.


def train():
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # Patch the attention forward pass to use Shift Short Attention
    # NOTE: May expand supported model types in the future
    if model_args.model_type == "gpt-neox":
        replace_gpt_neox_attn(training_args.use_flash_attn)
    else:
        assert model_args.model_type == "llama", "Only support llama and gpt-neox for now"
        replace_llama_attn(training_args.use_flash_attn)

    # Optional 4-bit quantization config. It only takes effect when passed as
    # quantization_config to AutoModelForCausalLM.from_pretrained; it is not
    # passed there below, so the model is loaded unquantized.
    compute_dtype = getattr(torch, "float16")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=False,
    )

    # Set RoPE scaling factor
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
        scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}

    # Load model and tokenizer
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        max_length=8192 * 4,
    )

    # Add any special tokens the tokenizer is missing (e.g. a pad token)
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    # Rank 0 prepares the dataset first; the other ranks wait, then load the cached result
    rank = int(os.environ.get('RANK', -1))
    if rank > 0:
        barrier()

    dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)
    dataset['train'] = dataset['train'].select(range(300000))

    indices_to_keep = []
    for i, example in enumerate(dataset['train']):
        if len(example['text']) <= 300000:  # filtering excessively long samples (in characters)
            indices_to_keep.append(i)

    # Filter the dataset to only keep those examples
    dataset['train'] = dataset['train'].select(indices_to_keep)

    dataset = dataset.map(partial(tokenize_fn, tokenizer), batched=True, num_proc=50, remove_columns=["text", "meta"], batch_size=200)

    if rank == 0:
        barrier()

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    if training_args.low_rank_training:
        if model_args.model_type == "gpt-neox":
            # added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
            targets = ["query_key_value", "dense"]
        else:
            targets = ["q_proj", "k_proj", "v_proj", "o_proj"]

        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=targets,
            lora_dropout=0,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        # enable trainable params: unfreeze every parameter whose name contains
        # one of the comma-separated entries in trainable_params (e.g. embed, norm)
        [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]

    # redefine training_args for clarity
    # NOTE: fields not listed here (e.g. max_steps, warmup_steps, lr_scheduler_type,
    # weight_decay, tf32) fall back to their defaults, so those command-line flags
    # are dropped. Also, `x or True` always evaluates to True.
    training_args = TrainingArguments(
        output_dir=training_args.output_dir or "./output",
        overwrite_output_dir=True,
        num_train_epochs=training_args.num_train_epochs or 100,
        per_device_train_batch_size=training_args.per_device_train_batch_size or 1,
        save_steps=training_args.save_steps or 100,
        save_total_limit=training_args.save_total_limit or 2,
        fp16=training_args.fp16 or True,
        bf16=training_args.bf16 or False,
        report_to=training_args.report_to or 'wandb',
        logging_steps=training_args.logging_steps or 10,
        gradient_checkpointing=training_args.gradient_checkpointing or True,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps or 8,
        learning_rate=training_args.learning_rate or 5e-5,
        cache_dir=training_args.cache_dir or './cache',
        deepspeed=training_args.deepspeed or "./ds_configs/stage2.json",
    )

    model.config.use_cache = False  # required for gradient checkpointing
    model.enable_input_require_grads()  # required for gradient checkpointing
    model.gradient_checkpointing_enable()  # enable gradient checkpointing
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=None,
        data_collator=data_collator)
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()




Wandb Logging

I've set the logging to go directly to Weights & Biases, as shown in the training arguments.
report_to=training_args.report_to or 'wandb',

Shift Short Attention

We can add Shift Short Attention to our model with a single function from the LongLora repo. It monkey-patches the forward method of the Llama attention module so that it uses Shift Short Attention in place of standard attention.
replace_llama_attn(training_args.use_flash_attn)

Rope Scaling

In the code, the RoPE scaling factor is the ratio of the desired context length to the context length the model was originally trained on (`max_position_embeddings`), rounded up to a whole number. For Llama 2, 32768 / 4096 gives a factor of 8, which "interpolates" the positional embeddings back into the trained range.
scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
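Because of the `math.ceil`, the factor is always a whole number. Here is a small sketch of the same arithmetic, assuming Llama 2's `max_position_embeddings` of 4096 (`rope_scaling_factor` is a name made up for this example):

```python
import math


def rope_scaling_factor(target_len, orig_len):
    """Linear RoPE scaling factor, rounded up to a whole number as in train()."""
    return float(math.ceil(target_len / orig_len))


for target in (8192, 16384, 32768, 10000):
    factor = rope_scaling_factor(target, 4096)
    covered = int(factor * 4096)  # context length the factor actually stretches to
    print(target, factor, covered)
```

For the 32,768-token run the ratio is exactly 8, so rounding changes nothing. For a target that is not a multiple of 4096, such as 10,000 tokens, the factor rounds up to 3, so positions are compressed as if the target were 12,288 tokens.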

Trainable Normalization and Embedding layers

In order to make the embedding and normalization layers trainable, it’s a matter of looping through the model parameters and making the desired layers (embed and norm) trainable.
[p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
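The selection is plain substring matching on parameter names. Below is a pure-Python sketch of the same filter, assuming `trainable_params` is the comma-separated string `embed,norm`; the parameter names are illustrative examples in the style of a PEFT-wrapped Llama model, not a dump from a real run, and `select_trainable` is a hypothetical helper.

```python
def select_trainable(param_names, trainable_params):
    """Return the names containing any comma-separated substring of trainable_params."""
    # Skip empty entries (e.g. from a trailing comma): "" would match every name
    keys = [k for k in trainable_params.split(",") if k]
    return [n for n in param_names if any(k in n for k in keys)]


# Illustrative names only
names = [
    "base_model.model.model.embed_tokens.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "base_model.model.model.layers.0.input_layernorm.weight",
    "base_model.model.model.layers.0.post_attention_layernorm.weight",
    "base_model.model.model.norm.weight",
    "base_model.model.lm_head.weight",
]
for name in select_trainable(names, "embed,norm"):
    print(name)
```

Any parameter whose name merely contains `norm` or `embed` is unfrozen, which here picks up both per-layer RMSNorms, the final norm, and the input embedding, but not `lm_head`. The LoRA `lora_A`/`lora_B` weights are already trainable through PEFT, so they don't need to match.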

DeepSpeed

A key component of the training pipeline is DeepSpeed, which helps avoid those pesky "CUDA out of memory" errors that haunt every ML programmer. I'm using ZeRO stage 2 with CPU offloading for this run, and my config file is below. Note that it differs slightly from the default config in the official LongLora repo.
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16":{
"enabled":"auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"round_robin_gradients": true,
"cpu_offload": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
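One detail worth knowing: recent DeepSpeed releases document the ZeRO `cpu_offload` flag as deprecated in favor of an `offload_optimizer` section. If your DeepSpeed version warns about it, a small helper like the hypothetical `migrate_cpu_offload` below can rewrite the config; `pin_memory` is an optional setting chosen for this sketch, not something taken from the config above.

```python
import json


def migrate_cpu_offload(ds_config):
    """Return a copy of a DeepSpeed config dict where the deprecated ZeRO
    `cpu_offload` flag is replaced by an `offload_optimizer` section."""
    config = json.loads(json.dumps(ds_config))  # deep copy of plain JSON data
    zero = config.get("zero_optimization", {})
    if zero.pop("cpu_offload", False):
        zero["offload_optimizer"] = {"device": "cpu", "pin_memory": True}
    return config


# Usage with the file from the run command (output path is hypothetical):
# with open("ds_configs/stage2.json") as f:
#     cfg = migrate_cpu_offload(json.load(f))
# with open("ds_configs/stage2_offload.json", "w") as f:
#     json.dump(cfg, f, indent=2)
```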

Run Command

Here's the run command for fine-tuning. You may need to adjust it for the hardware available to you; I was able to spin up 8 NVIDIA A6000s for this run!
torchrun --nproc_per_node=8 train.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--bf16 False \
--output_dir ./output \
--cache_dir ./cache \
--model_max_length 32768 \
--use_flash_attn False \
--low_rank_training True \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 100 \
--save_total_limit 2 \
--learning_rate 5e-5 \
--weight_decay 0.0 \
--warmup_steps 20 \
--lr_scheduler_type "constant_with_warmup" \
--logging_steps 1 \
--deepspeed "ds_configs/stage2.json" \
--tf32 True \
--max_steps 1000
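With data parallelism across 8 GPUs, a per-device batch size of 1, and 8 gradient-accumulation steps, each optimizer step sees 8 × 1 × 8 = 64 sequences. The sketch below turns that into tokens, assuming every sequence is packed to the full 32,768-token `model_max_length`, so treat it as an upper bound (`tokens_per_step` is a helper made up for this example).

```python
def tokens_per_step(num_gpus, per_device_batch, grad_accum, seq_len):
    """Sequences and tokens consumed per optimizer step under data parallelism."""
    sequences = num_gpus * per_device_batch * grad_accum
    return sequences, sequences * seq_len


seqs, toks = tokens_per_step(num_gpus=8, per_device_batch=1, grad_accum=8, seq_len=32768)
print(seqs, toks)   # 64 2097152
print(100 * toks)   # tokens seen after 100 optimizer steps
```

That is 2,097,152 tokens per optimizer step, or about 210 million tokens over 100 steps.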



Training

I was able to train the model for 100 steps on a 300k-sample subset of the RedPajama dataset!


Run: azure-sound-62


In conclusion, we have successfully scaled Llama 2 to 32k tokens! The pace of progress here is truly astounding. Context length is a critical component of any LLM, and 32k tokens is only a starting point. One can only imagine the possibilities of scaling to 1 million or even 10 million tokens, and the functionality that could arise from such long context windows is really exciting.
Overall, I hope you enjoyed this guide to training Llama 2 7B on 32k tokens; feel free to drop a comment if you have any questions! For the code, check out the GitHub repo here as well as the official LongLora repo here.

