
LoRA vs semi-FT vs Full FT

A quick comparison of fine-tuning techniques
Created on November 16 | Last edited on October 5

Recipe

  • Base model: Llama 2 7B (meta-llama/Llama-2-7b-hf)
  • Fine-tuned on Alpaca for 3 epochs.
  • Sequences are packed into blocks of 1024 tokens.
  • The 8-layer finetune (semi-FT) freezes the first 24 of the model's 32 decoder layers and trains only the last 8 plus the output (LM) head, as explained here.
  • Batch size is 16 with gradient accumulation of 2, except for LoRA, which uses a batch size of 8 with gradient accumulation of 4 because 16 ran out of memory (OOM). Every run therefore has an effective batch size of 32.
  • All runs use trl's SFTTrainer except full-ft-pytorch, which is a pure PyTorch training loop.
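To make the 8-layer finetune concrete, here is a minimal sketch that rebuilds Llama-2-7b's parameter names and sizes from its public config (hidden size 4096, MLP size 11008, 32 layers, 32k vocabulary) and applies a name-based freezing rule. The helper names are hypothetical. Keeping the embeddings and the final norm frozen is an assumption, because the recipe only mentions the decoder layers and the head.

```python
# Hypothetical sketch: Llama-2-7b parameter names/sizes rebuilt from its public config,
# plus a name-based rule for the "freeze the first 24 layers" finetune.
HIDDEN = 4096         # hidden_size
INTERMEDIATE = 11008  # MLP intermediate_size
N_LAYERS = 32         # num_hidden_layers
VOCAB = 32000         # vocab_size
N_FROZEN = 24         # decoder layers kept frozen


def llama2_7b_param_sizes():
    """Return {parameter name: number of elements}, following Hugging Face Llama naming."""
    sizes = {"model.embed_tokens.weight": VOCAB * HIDDEN}
    for i in range(N_LAYERS):
        prefix = f"model.layers.{i}."
        for proj in ("q_proj", "k_proj", "v_proj", "o_proj"):
            sizes[prefix + f"self_attn.{proj}.weight"] = HIDDEN * HIDDEN
        for proj in ("gate_proj", "up_proj", "down_proj"):
            sizes[prefix + f"mlp.{proj}.weight"] = HIDDEN * INTERMEDIATE
        sizes[prefix + "input_layernorm.weight"] = HIDDEN
        sizes[prefix + "post_attention_layernorm.weight"] = HIDDEN
    sizes["model.norm.weight"] = HIDDEN
    sizes["lm_head.weight"] = VOCAB * HIDDEN
    return sizes


def is_trainable(name, n_frozen=N_FROZEN):
    """Train decoder layers with index >= n_frozen and the LM head.
    Assumption: embeddings and the final norm stay frozen."""
    if name.startswith("lm_head."):
        return True
    if name.startswith("model.layers."):
        return int(name.split(".")[2]) >= n_frozen
    return False


sizes = llama2_7b_param_sizes()
total = sum(sizes.values())
trainable = sum(n for name, n in sizes.items() if is_trainable(name))
print(f"total params:     {total:,}")                                # 6,738,415,616
print(f"trainable params: {trainable:,} ({trainable / total:.1%})")  # 1,750,138,880 (26.0%)
```

With a real model loaded, the same rule would be applied as `for name, param in model.named_parameters(): param.requires_grad = is_trainable(name)`. Roughly a quarter of the weights then receive gradients, compared with all of them in full fine-tuning.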

LoRA config

The exact execution of my LoRA run is captured here (thanks W&B).
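import torch
from peft import LoraConfig

model_id = 'meta-llama/Llama-2-7b-hf'

# kwargs that SFTTrainer forwards to from_pretrained when it loads model_id
model_kwargs = dict(
    device_map=0,                 # place the whole model on GPU 0
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    use_cache=False,              # KV cache is unused in training and incompatible with gradient checkpointing
)

peft_config = LoraConfig(
    r=64,                # the rank of the LoRA matrices
    lora_alpha=16,       # scaling factor: the LoRA update is multiplied by lora_alpha / r = 0.25
    lora_dropout=0.1,    # dropout applied to the input of the LoRA branch
    bias="none",         # do not train any bias parameters
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections that receive LoRA adapters
)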

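# trainer setup
from transformers import TrainingArguments
from trl import SFTTrainer
from llm_recipes.utils import LLMSampleCB  # author's own helper from llm_recipes; not used in this excerpt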

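# LoRA run settings from the recipe: 8 x 4 = 32 effective batch size
batch_size = 8
gradient_accumulation_steps = 4
num_train_epochs = 3
# total_num_steps comes from the size of the packed training set (computed earlier, not shown)

output_dir = "./output/"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size // 2,
    bf16=True,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,                 # linear warmup over the first 10% of steps
    max_steps=total_num_steps,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=True,      # recompute activations to save memory
    gradient_checkpointing_kwargs=dict(use_reentrant=False),
    evaluation_strategy="steps",      # renamed eval_strategy in newer transformers releases
    eval_steps=total_num_steps // num_train_epochs,   # evaluate once per epoch
    # logging strategies
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="steps",
    save_steps=total_num_steps // num_train_epochs,   # checkpoint once per epoch
)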

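# packing, max_seq_length and model_init_kwargs follow the trl API of the time;
# newer trl releases move these options to SFTConfig
trainer = SFTTrainer(
    model=model_id,                  # SFTTrainer loads the model from this Hub id...
    model_init_kwargs=model_kwargs,  # ...passing these kwargs to from_pretrained
    train_dataset=train_dataset,     # Alpaca data (prepared earlier, not shown)
    eval_dataset=eval_dataset,
    packing=True,                    # concatenate examples into fixed-length blocks
    max_seq_length=1024,
    args=training_args,
    formatting_func=create_prompt,   # formats an Alpaca record as a prompt string (not shown)
    peft_config=peft_config,         # SFTTrainer wraps the model with the LoRA adapters
)
trainer.train()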
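The LoraConfig in the code above trains only low-rank adapters on the four attention projections. This minimal pure-Python sketch (hypothetical helper names; Llama-2-7b shapes assumed: 32 layers, 4096x4096 q/k/v/o projections) shows what r=64 and lora_alpha=16 imply. Each adapted layer computes W x + (lora_alpha / r) * B A x with W frozen, which for this config works out to about 67M trainable parameters and a scaling of 0.25.

```python
# Sketch of what the LoraConfig above implies for Llama-2-7b (helper names are hypothetical).
HIDDEN = 4096                      # q/k/v/o projections are HIDDEN x HIDDEN in Llama-2-7b
N_LAYERS = 32
TARGET_MODULES = ("q_proj", "k_proj", "v_proj", "o_proj")
LLAMA2_7B_TOTAL = 6_738_415_616    # total parameters of the base model


def lora_param_count(r, d_in, d_out):
    """A has shape (r, d_in) and B has shape (d_out, r); only these are trained."""
    return r * d_in + d_out * r


def lora_scaling(lora_alpha, r):
    """Default (non-rsLoRA) scaling applied to the low-rank update."""
    return lora_alpha / r


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lora_forward(x, W, A, B, lora_alpha, r):
    """y = W x + (lora_alpha / r) * B (A x), with W frozen (dropout omitted)."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    s = lora_scaling(lora_alpha, r)
    return [b + s * u for b, u in zip(base, update)]


r, lora_alpha = 64, 16
lora_params = lora_param_count(r, HIDDEN, HIDDEN) * len(TARGET_MODULES) * N_LAYERS
print(lora_scaling(lora_alpha, r))                               # 0.25
print(f"{lora_params:,} ({lora_params / LLAMA2_7B_TOTAL:.1%})")  # 67,108,864 (1.0%)

# B starts at zero in PEFT's default init, so the adapted layer initially matches the base layer.
W = [[1.0, 2.0], [3.0, 4.0]]
A = [[0.5, -0.5]]
B = [[0.0], [0.0]]
print(lora_forward([1.0, 1.0], W, A, B, lora_alpha=16, r=1) == matvec(W, [1.0, 1.0]))  # True
```

Because the update is divided by r, raising the rank while keeping lora_alpha at 16 shrinks the per-update scale. This is one reason the two values are usually tuned together.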
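The TrainingArguments above derive the warmup, evaluation, and checkpoint intervals from total_num_steps. This sketch walks through that arithmetic for the LoRA run (batch size 8, gradient accumulation 4, 3 epochs). The number of packed sequences is a made-up value, since the report does not say how total_num_steps was computed. The warmup rounding and the cosine curve follow how transformers handles warmup_ratio and lr_scheduler_type="cosine": linear warmup, then a half-cosine decay to zero.

```python
import math


def effective_batch_size(per_device_batch, grad_accum, num_devices=1):
    """Examples contributing to each optimizer step."""
    return per_device_batch * grad_accum * num_devices


def cosine_with_warmup(step, base_lr, warmup_steps, total_steps):
    """Linear warmup, then half-cosine decay to 0 (the shape used by
    transformers' get_cosine_schedule_with_warmup with default num_cycles=0.5)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Both configurations in the recipe reach the same effective batch size.
print(effective_batch_size(16, 2), effective_batch_size(8, 4))  # 32 32

# Hypothetical dataset size: number of packed 1024-token sequences in one epoch.
n_packed_sequences = 12_800
num_train_epochs = 3
steps_per_epoch = n_packed_sequences // effective_batch_size(8, 4)  # 400
total_num_steps = steps_per_epoch * num_train_epochs                # 1200
warmup_steps = math.ceil(total_num_steps * 0.1)                     # warmup_ratio=0.1 -> 120
eval_steps = total_num_steps // num_train_epochs                    # 400: once per epoch

for step in (0, warmup_steps, total_num_steps // 2, total_num_steps):
    print(step, cosine_with_warmup(step, 2e-4, warmup_steps, total_num_steps))
```

Setting eval_steps and save_steps to total_num_steps // num_train_epochs is simply a way to evaluate and checkpoint once per epoch when training is driven by max_steps instead of num_train_epochs.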
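With packing=True and max_seq_length=1024, SFTTrainer does not pad each Alpaca example separately. In trl releases of that period it roughly concatenated the tokenized examples into one stream, separated by the EOS token, and cut that stream into fixed-length blocks, dropping the incomplete remainder. Below is a toy pure-Python sketch of the idea, using a block length of 4 instead of 1024 and made-up token ids (2 stands in for EOS, which is Llama 2's EOS id). pack_sequences is a hypothetical helper, not trl's API.

```python
def pack_sequences(tokenized_examples, seq_length, eos_token_id):
    """Concatenate tokenized examples, each followed by EOS, and cut the stream
    into blocks of exactly seq_length tokens; the trailing partial block is dropped."""
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids + [eos_token_id])
    blocks = []
    for start in range(0, len(stream), seq_length):
        block = stream[start:start + seq_length]
        if len(block) == seq_length:
            blocks.append(block)
    return blocks


# Toy example: seq_length=4 instead of 1024; token ids are made up.
examples = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]
packed = pack_sequences(examples, seq_length=4, eos_token_id=2)
print(packed)  # [[10, 11, 12, 2], [20, 21, 2, 30], [31, 32, 33, 2]]
```

Packing wastes no compute on padding, but a block can span two examples (the second block above mixes the end of one example with the start of the next). This is also why the number of training steps depends on the packed dataset size, not on the raw example count.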

Results

