
Fine-Tuning Llama-3 with LoRA: TorchTune vs HuggingFace

A battle between HuggingFace and TorchTune!
Meta recently unveiled Llama 3, the latest iteration of its flagship open-source LLM, which comes in two variants boasting 8 billion and 70 billion parameters, respectively. These models are designed to perform exceptionally well in language understanding, providing deep contextual insights and handling complex tasks such as translation and dialogue generation with ease.
With a training corpus spanning over 15 trillion tokens, Llama 3 supports a context length of 8,192 tokens, double the capacity of its predecessor.
In this article, we will fine-tune Llama 3 using TorchTune and HuggingFace, and see which wins!


Llama 3 Benchmarks

The benchmarks for Llama 3 are impressive. The 8B variant achieves 68.4% on MMLU 5-shot, 34.2% on GPQA 0-shot, and 62.2% on HumanEval 0-shot. The 70B variant demonstrates even more extraordinary performance, scoring 82.0% on MMLU 5-shot, 39.5% on GPQA 0-shot, and 81.7% on HumanEval 0-shot.


The TorchTune Library

Fine-tuning large language models has become essential for achieving top-tier results across various tasks. TorchTune, a library specifically crafted to streamline the fine-tuning process on the PyTorch framework, stands out in this field. It provides tools and functionalities specifically designed to optimize, manage, and deploy fine-tuned models effectively.

What is TorchTune?

TorchTune is a robust library for fine-tuning LLMs on PyTorch. It is designed to make adapting pre-trained models to specific tasks both accessible and efficient, addressing common challenges such as managing training workflows, optimizing memory usage, and integrating with experiment tracking systems like Weights & Biases.

Key Features of TorchTune

Configurability: TorchTune uses YAML configuration files, allowing users to specify and modify training parameters, model components, and dataset details easily.
Advanced Memory Management: Features such as activation checkpointing and support for reduced precision training help reduce the memory footprint significantly.
Integration with Distributed Computing: TorchTune supports distributed training, crucial for scaling the fine-tuning process across multiple GPUs or nodes.

How Does TorchTune Work?

TorchTune operates through defined components and stages, from setting up via YAML configuration files to training, checkpointing, logging, and deployment. It integrates seamlessly with PyTorch, allowing the models to be easily exported for production use.
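To make the YAML-driven flow concrete, here is a rough sketch of the idea behind the _component_ keys you will see in the config later: a dotted path is resolved to a Python callable, and the remaining keys become its keyword arguments. This is a simplified illustration of the concept, not TorchTune's actual implementation (TorchTune ships its own config instantiation utilities for this):

import importlib

# Illustrative helper: resolve a dotted "_component_" path and call it with the
# remaining keys as keyword arguments (this mirrors, in spirit, how TorchTune
# builds models, tokenizers, datasets, and optimizers from YAML configs).
def instantiate_component(cfg: dict):
    dotted_path = cfg.pop("_component_")                     # e.g. "torch.nn.CrossEntropyLoss"
    module_name, attr_name = dotted_path.rsplit(".", 1)
    component = getattr(importlib.import_module(module_name), attr_name)
    return component(**cfg)                                  # remaining keys become kwargs

# Example: the loss section of the YAML config shown later in this article
loss_cfg = {"_component_": "torch.nn.CrossEntropyLoss"}
loss_fn = instantiate_component(dict(loss_cfg))              # -> torch.nn.CrossEntropyLoss()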

What Is LoRA?

Low-Rank Adaptation (LoRA) is a method that has gained significant traction in machine learning, particularly for training and adapting large language models. LoRA improves the adaptability and efficiency of these models through targeted modifications that do not require altering the entire network.
LoRA operates on the premise that the primary parameters of a pre-trained neural network can remain fixed, and only a small subset of parameters is optimized during the fine-tuning phase. This is achieved by introducing low-rank matrices into specific parts of the model, such as the attention and feed-forward layers of transformer architectures. These matrices are much smaller than the main model parameters, which means they require far fewer computational resources to update.
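As a concrete (and deliberately simplified) sketch, here is the standard LoRA formulation applied to a single linear layer: the frozen weight is left untouched, and a trainable low-rank update scaled by alpha divided by the rank is added on top. This is for intuition only, not the code TorchTune or PEFT actually use:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA wrapper around a frozen nn.Linear (illustrative only)."""
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)          # freeze the pre-trained weight
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank                     # lora_alpha / lora_rank

    def forward(self, x):
        # Frozen path plus low-rank update: W x + (alpha / r) * B (A x)
        return self.base(x) + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)

# Example: wrap a 4096x4096 projection (the hidden size used by Llama 3 8B)
layer = LoRALinear(nn.Linear(4096, 4096, bias=False), rank=8, alpha=16)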
[Figure: Adding Adapter Layers]

Benefits of Using LoRA

Efficiency: By limiting the number of trainable parameters, LoRA reduces the computational overhead associated with training large models. This is particularly beneficial in environments with limited GPU resources.
Flexibility: LoRA allows for rapid adaptation of pre-trained models to new tasks or datasets without the need for extensive retraining.
Preservation of Pre-trained Knowledge: Since the core structure of the model remains unchanged, LoRA preserves the rich representations learned during the extensive pre-training phase, ensuring that the fine-tuned model does not deviate significantly from its original capabilities.

Implementation of LoRA

Implementing LoRA involves augmenting the existing layers of a model with trainable low-rank matrices. For instance, in a transformer model, LoRA matrices can be applied to the projection matrices within the attention mechanism. This allows the model to adapt these matrices specifically while keeping the rest of the architecture intact.
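To get a feel for the savings, here is a quick back-of-the-envelope calculation of how many trainable parameters LoRA adds to a single projection matrix, assuming Llama 3 8B's 4096-dimensional hidden size and the rank of 8 used later in this article:

# Back-of-the-envelope: LoRA parameters for one 4096x4096 projection at rank 8
d_model, rank = 4096, 8

full_params = d_model * d_model                 # 16,777,216 weights if the matrix is trained fully
lora_params = rank * d_model + d_model * rank   # A is (r x d), B is (d x r) -> 65,536

print(f"Full fine-tuning: {full_params:,} params")
print(f"LoRA (rank {rank}): {lora_params:,} params "
      f"({100 * lora_params / full_params:.2f}% of the full matrix)")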

Fine-Tuning Llama 3: TorchTune vs HuggingFace

The experiment setup is a single NVIDIA A6000 GPU with 48GB of VRAM, fully utilized, training on the Alpaca dataset for a single epoch. We will compare the performance of the Llama 3 model when fine-tuned using TorchTune with a LoRA-based approach against a similar setup using HuggingFace's transformers library. The goal is an apples-to-apples comparison of the two libraries in terms of total throughput.
I tried to match all parameters across the two runs. There are likely some mild discrepancies in areas like the data loaders, but for all intents and purposes it's a fairly equal setup for both libraries. I was able to achieve full GPU utilization with a batch size of 1 for HuggingFace and a batch size of 17 for TorchTune, so I used 17 gradient accumulation steps for the HuggingFace run to keep the effective batch size per optimizer step the same. Additionally, I used a relatively short sequence length of 256 tokens, which was necessary to train on a single GPU with the HuggingFace Trainer.

The Code

Ok, now we are ready to dive into the code. For TorchTune, there are a few steps that we will need to follow to start fine-tuning. First, start by creating a folder where the weights will be stored, using the command mkdir models.
Next, we will need to obtain our Llama 3 weights. We can do this by running the following command:
tune download meta-llama/Meta-Llama-3-8B --output-dir ./models --hf-token <HF_TOKEN>
Once the weights have been downloaded, we will create our config file. Create a file called llama3.yaml and add the following contents to it:

# Model Arguments
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /home/models/original/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: ./models
  checkpoint_files: [
    /home/models/original/consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: ./out
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
  max_seq_len: 256
seed: null
shuffle: True
batch_size: 47

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 64
compile: False

# Logging
output_dir: ./out/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  log_dir: torchtune-llama3
log_every_n_steps: 5

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Profiler (disabled)
profiler:
  _component_: torchtune.utils.profiler
  enabled: False
This config file is what TorchTune uses to parse the arguments used for our training run. We utilize Weights & Biases for logging and monitoring our training process, which will be crucial for tracking our model’s performance and making any necessary adjustments. The use of Weights & Biases allows us to visualize metrics in real-time, adding a layer of transparency and control over the training cycle. Here is the portion of the config that will enable logging with Weights & Biases:
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  log_dir: torchtune-llama3
log_every_n_steps: 5
We also specify the dataset, shown below:
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
  max_seq_len: 256
I can also declare my LoRA parameters in the config. Here is an overview of the parameters I used, along with a short sketch after the list showing how they map to TorchTune's model builder:
lora_attn_modules: Specifies which components of the attention mechanism LoRA is applied to. Common choices include 'q_proj' (query projection) and 'v_proj' (value projection), as these are key areas where adaptations can significantly impact model performance by modifying how inputs are weighted and combined.
lora_rank: Determines the rank of the low-rank matrices that are used to approximate the original weight matrices in the specified modules. A lower rank means fewer parameters to train, which can speed up training and reduce overfitting but may limit the model’s capacity to learn complex patterns.
lora_alpha: A scaling factor that adjusts the magnitude of the updates applied through the LoRA parameters. This can be critical for balancing the influence of the LoRA-enhanced components on the model’s behavior.
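Under the hood, TorchTune passes these YAML keys as keyword arguments to the lora_llama3_8b builder named in the model section. The call below sketches the equivalent Python; the exact builder signature is assumed to match the TorchTune release used here, so double-check it against your installed version:

from torchtune.models.llama3 import lora_llama3_8b

# Sketch of the "model:" section of llama3.yaml expressed as a direct call
# (keyword names assumed to match the installed torchtune release)
model = lora_llama3_8b(
    lora_attn_modules=["q_proj", "v_proj"],  # apply LoRA to query/value projections
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    lora_rank=8,
    lora_alpha=16,
)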
With the configuration file ready, the next step is to launch the training process. Execute the training command as follows:
tune run lora_finetune_single_device --config llama3.yaml
This command initiates the training session using the configurations specified in the llama3.yaml file. The script sets up the environment, loads the model and tokenizer, prepares the dataset, and enters the training loop according to the defined epochs, batch size, learning rate schedule, and other parameters. If you are interested in the details of TorchTune's implementation, I recommend checking out the TorchTune repository on GitHub, as it's very well organized and concise!
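If you want a mental model of what the recipe is doing, here is a heavily simplified, generic PyTorch sketch of a single-device fine-tuning loop with gradient accumulation. It illustrates the structure of such a loop, not TorchTune's actual recipe code:

import torch

def train_one_epoch(model, dataloader, optimizer, scheduler, loss_fn,
                    grad_accum_steps: int, device: str = "cuda"):
    """Simplified single-device training loop with gradient accumulation."""
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(dataloader):
        tokens, labels = batch["tokens"].to(device), batch["labels"].to(device)
        logits = model(tokens)
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        (loss / grad_accum_steps).backward()     # scale so accumulated grads average correctly
        if (step + 1) % grad_accum_steps == 0:
            optimizer.step()                     # one optimizer update per accumulated batch
            scheduler.step()
            optimizer.zero_grad()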
As the model trains, you can monitor various metrics such as loss and learning rate through the Weights & Biases dashboard. Here are the results of my training run:

[W&B panel: training metrics for run honest-snow-19]

The run took about an hour and forty minutes to complete for one epoch!


Running Inference with TorchTune

Now we are ready to test out our model! After training, you will see two model files written to your specified output directory. The file meta_model_0.pt contains the weights of your fine-tuned model, with the LoRA weights already merged in. Next, we need a config file for inference. Here is mine:
# Model Arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: ./out
  checkpoint_files: [
    /home/out/meta_model_0.pt
  ]
  # adapter_checkpoint: adapter_0.pt
  recipe_checkpoint: null
  output_dir: ./out
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 1234

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /home/models/original/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Hello, my name is"
max_new_tokens: 256
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

quantizer: null
As can be seen, we pass the path to our model, along with the prompt, the number of new tokens to generate, the temperature, and top_k, which restricts sampling to the k most likely next-token candidates. Now we can run the following command, which runs inference with the model:
tune run generate --config ./gen_config.yaml
After a short wait, we will see the response!
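For intuition on what temperature and top_k control, here is a minimal sketch of top-k sampling for a single next-token prediction. It illustrates the general technique rather than TorchTune's generation code:

import torch

def sample_top_k(logits: torch.Tensor, temperature: float = 0.6, top_k: int = 300) -> int:
    """Sample one token id from logits using temperature scaling and top-k filtering."""
    scaled = logits / max(temperature, 1e-5)            # lower temperature -> sharper distribution
    top_values, top_indices = torch.topk(scaled, k=min(top_k, scaled.numel()))
    probs = torch.softmax(top_values, dim=-1)           # renormalize over the k best candidates
    choice = torch.multinomial(probs, num_samples=1)    # stochastic pick among them
    return top_indices[choice].item()

# Example with a toy vocabulary of 10 tokens
next_token = sample_top_k(torch.randn(10), temperature=0.6, top_k=5)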


Training Llama 3 With HuggingFace

OK, now we are ready to run a similar training job using the HuggingFace transformers library. I've prepared a script whose settings mirror our previous TorchTune config file. The script loads the model and dataset with HuggingFace and uses the SFTTrainer to train our model. The full script is below:
import os
from dataclasses import dataclass, field
from typing import Optional
from datasets.arrow_dataset import Dataset
import torch
from datasets import load_dataset
from peft import LoraConfig
from peft import AutoPeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
)

from trl import SFTTrainer

torch.manual_seed(42)

@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
    """

    local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})

    per_device_train_batch_size: Optional[int] = field(default=1)
    per_device_eval_batch_size: Optional[int] = field(default=4)
    gradient_accumulation_steps: Optional[int] = field(default=17)
    learning_rate: Optional[float] = field(default=3e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[float] = field(default=0.01)
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.0)
    lora_r: Optional[int] = field(default=8)
    max_seq_length: Optional[int] = field(default=256)
    model_name: Optional[str] = field(
        # default="mistralai/Mistral-7B-Instruct-v0.1",
        default="meta-llama/Meta-Llama-3-8B",
        # default="TinyLlama/TinyLlama-1.1B-step-50K-105b",
        metadata={
            "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
        },
    )
    dataset_name: Optional[str] = field(
        default="tatsu-lab/alpaca",
        metadata={"help": "The instruction dataset to use."},
    )

    use_4bit: Optional[bool] = field(
        default=True,
        metadata={"help": "Activate 4bit precision base model loading"},
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    num_train_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "The number of training epochs."},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Use packing when creating the dataset."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables gradient checkpointing."},
    )
    optim: Optional[str] = field(
        default="adamw_torch",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        # default="cosine_with_warmup",
        default="cosine",
        metadata={"help": "Learning rate schedule."},
    )
    max_steps: int = field(default=-1, metadata={"help": "How many optimizer update steps to take (-1 to train for num_train_epochs)"})
    warmup_steps: int = field(default=100, metadata={"help": "# of steps to do a warmup for"})
    group_by_length: bool = field(
        default=True,
        metadata={
            "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably."
        },
    )
    save_steps: int = field(default=200, metadata={"help": "Save checkpoint every X update steps."})
    logging_steps: int = field(default=5, metadata={"help": "Log every X update steps."})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={"help": "Where to report training metrics (referenced by TrainingArguments below)."},
    )
    merge_and_push: Optional[bool] = field(
        default=False,
        metadata={"help": "Merge and push weights after training"},
    )
    output_dir: str = field(
        default="./results_packing",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


def gen_batches_train():
    ds = load_dataset(script_args.dataset_name, streaming=True, split="train")

    for sample in iter(ds):
        # Extract instruction, input, and output from the sample
        instruction = str(sample['instruction'])
        input_text = str(sample['input'])
        out_text = str(sample['output'])

        if input_text is None or input_text == "":
            formatted_prompt = (
                f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
                f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
                f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
                f"{out_text}"
                f"<|eot_id|><|end_of_text|>"
            )
        else:
            formatted_prompt = (
                f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
                f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
                f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
                f"{out_text}"
                f"<|eot_id|><|end_of_text|>"
            )

        yield {'text': formatted_prompt}



def create_and_prepare_model(args):
    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
    # QLoRA quantization config (disabled for this run)
    # bnb_config = BitsAndBytesConfig(
    #     load_in_4bit=args.use_4bit,
    #     bnb_4bit_quant_type=args.bnb_4bit_quant_type,
    #     bnb_4bit_compute_dtype=compute_dtype,
    #     bnb_4bit_use_double_quant=args.use_nested_quant,
    # )

    if compute_dtype == torch.float16 and args.use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
            print("=" * 80)

    # Load the entire model on GPU 0
    # switch to `device_map = "auto"` for multi-GPU
    device_map = {"": 0}

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        # quantization_config=bnb_config,
        device_map=device_map,
        use_auth_token=True,
    )
    peft_config = LoraConfig(
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        # target_modules=["query_key_value"],
        r=script_args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=['q_proj', 'v_proj'],
    )

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    return model, peft_config, tokenizer


training_arguments = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
    max_grad_norm=script_args.max_grad_norm,
    num_train_epochs=script_args.num_train_epochs,  # train for one full epoch
    max_steps=script_args.max_steps,
    warmup_steps=script_args.warmup_steps,
    group_by_length=script_args.group_by_length,
    lr_scheduler_type=script_args.lr_scheduler_type,
    report_to=script_args.report_to,
)

model, peft_config, tokenizer = create_and_prepare_model(script_args)

train_gen = Dataset.from_generator(gen_batches_train)

tokenizer.padding_side = "right"

trainer = SFTTrainer(
    model=model,
    train_dataset=train_gen,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=script_args.packing,
)

trainer.train()

if script_args.merge_and_push:
    output_dir = os.path.join(script_args.output_dir, "final_checkpoints")
    trainer.model.save_pretrained(output_dir)

    # Free memory for merging weights
    del model
    torch.cuda.empty_cache()

    model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
    model = model.merge_and_unload()

    output_merged_dir = os.path.join(script_args.output_dir, "final_merged_checkpoint")
    model.save_pretrained(output_merged_dir, safe_serialization=True)

Here, we use the peft library to implement LoRA, with the same LoRA parameters as the TorchTune run. Additionally, I set 'wandb' as the report_to argument for the HuggingFace Trainer (the default in the ScriptArguments above), which automatically logs our results to Weights & Biases. To run the script, enter the command below:
python train.py
This will train your Llama 3 model using HuggingFace Transformers! Here are the results from my training run:

[W&B panel: training metrics for run major-rain-130]


The run took about two hours and twenty-one minutes. Overall, TorchTune completed one epoch on the tatsu-lab/alpaca dataset (available via HuggingFace) roughly 30 percent faster than the HuggingFace Trainer (about 1 hour 40 minutes vs. 2 hours 21 minutes). I spoke to some of the developers of TorchTune, and they mentioned that further speedups could likely be achieved by slightly reducing the batch size from its maximum value (to avoid malloc retries). They also mentioned that updates currently in the works for TorchTune should boost throughput even further.

Further Tuning

Depending on the initial outcome, you may find it necessary to tweak certain parameters. This could involve adjusting the learning rate, changing the batch size, modifying the number of gradient accumulation steps, or altering the model’s LoRA settings like lora_rank and lora_alpha. Each of these adjustments could potentially improve the model's performance or efficiency.
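TorchTune makes this kind of iteration convenient because, if I recall correctly, the CLI accepts key=value overrides after the --config flag, so you can tweak values without editing the YAML each time. Treat the exact override syntax below as an assumption to verify against the TorchTune documentation for your installed version:
tune run lora_finetune_single_device --config llama3.yaml batch_size=32 optimizer.lr=1e-4 model.lora_rank=16 model.lora_alpha=32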

Overall

Meta's unveiling of Llama 3 signifies a notable advancement in the field of large language models, showcasing both the rapid progression and the ambitious scale at which AI development is proceeding. This new model, which includes versions with 8 billion and 70 billion parameters, is a major milestone.
Furthermore, the introduction of TorchTune, a fine-tuning library for PyTorch, and innovations like LoRA reflect a growing focus on not just enhancing AI capabilities but also on optimizing the efficiency and adaptability of these models for practical application. These advancements promise to significantly expand the potential of LLMs to understand and interact in nuanced human-like ways, marking a leap forward that could reshape how we interact with technology across various domains.
If you are interested in the source code, I created a repo here! I hope you enjoyed this tutorial!


