
Zephyr-7B: Fine-Tuning and Inference with W&B

Fine-tuning Zephyr-7B on AgentInstruct and running inference with W&B.
The purpose of this article is to highlight Zephyr-7B and AgentInstruct, how you can fine-tune your own model through Colab with Hugging Face, and how to run simple inference with W&B!

This article is accompanied by this notebook. Follow along! For the wandb project, click here.

🍃 What is Zephyr-7B?

Zephyr-7B is a pair of models developed by the Hugging Face H4 team, built on the popular Mistral-7B model: Zephyr-7B-α and Zephyr-7B-β (we'll be using Zephyr-7B-α for this report). These models not only outperform the Mistral-7B models but are comparable to LLaMA2-Chat-70B, a model 10 times their size!

The method here can be broken down into 3 steps: distilled supervised fine-tuning (dSFT), AI feedback (AIF), and distilled direct preference optimization (dDPO).



🧑‍🏫 Distilled supervised fine-tuning (dSFT)

Distilled supervised fine-tuning (dSFT) teaches our model to respond to instructions and prompts. It's traditionally done with supervised fine-tuning (SFT) on a high-quality dataset of instructions and responses. Zephyr, instead, leverages a teacher model to generate these high-quality responses, effectively "distilling" some of its capability to our model. You can also think of it as pseudo-labeling.
Say you have a list of seed prompts \{x_1, ..., x_j\}. For each seed prompt x_i, you both sample a response and refine the prompt with a teacher model (GPT-4), ending up with (\hat{x}_i, y_i), where \hat{x}_i is the refined prompt and y_i is the teacher model's response to the seed prompt x_i.
The end dataset is: C = \{(\hat{x}_1, y_1), ..., (\hat{x}_j, y_j)\}
The model is then instruction-tuned to optimize for this equation:
\pi_{dSFT} = \underset{\pi}{\max} \; \underset{(x, y) \sim C}{\mathbb{E}} \, \log \pi(y|x)

TL;DR: this equation says to find the model (weights) \pi that maximizes the expected log probability of a response y given its refined prompt x, with (x, y) sampled from the dataset C.
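To make the objective concrete, here's a minimal PyTorch-style sketch of the per-example dSFT loss, assuming a Hugging Face causal LM and its tokenizer (device placement omitted). Minimizing this NLL averaged over C is equivalent to maximizing the expectation above; masking the prompt tokens with -100 is a common convention, not necessarily the paper's exact setup.

import torch

def dsft_loss(model, tokenizer, prompt: str, response: str) -> torch.Tensor:
    """Negative log-likelihood of the response given the prompt, for one (x, y) pair from C."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + response, return_tensors="pt").input_ids

    labels = full_ids.clone()
    labels[:, : prompt_ids.shape[1]] = -100  # only score the response tokens y

    # Hugging Face causal LMs compute the shifted cross-entropy for us.
    out = model(input_ids=full_ids, labels=labels)
    return out.loss  # minimizing this maximizes log pi(y|x)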

🔊 AI feedback (AIF)

It's common to use human feedback to align LLMs. Zephyr, however, uses AI feedback (AIF).
Starting with a collection of 4 different models (e.g., Claude, Llama, Falcon), each prompt x_i \sim \{x_1, ..., x_j\} is fed through all 4 models to produce responses (y^1_i, y^2_i, y^3_i, y^4_i). The teacher model, GPT-4, then gives each response a score: s^{\{1, 2, 3, 4\}} = \pi_{T}(\cdot|x_i, y_i^{\{1, 2, 3, 4\}}). The highest-scoring response is called y_w, and a randomly chosen lower-scoring response is called y_l.
Thus, from a list of prompts \{x_1, ..., x_j\}, we derive a dataset D = \{(x_1, y_1^w, y_1^l), ..., (x_j, y_j^w, y_j^l)\}. These are 3-tuples of prompts with a stronger and a weaker response.
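If it helps, here's roughly what assembling D could look like in code. This is only a sketch: score_with_teacher is a hypothetical stand-in for the GPT-4 scoring call, and models is any list of 4 response generators.

import random

def build_preference_dataset(prompts, models, score_with_teacher):
    """Build (x, y_w, y_l) triples from 4 model responses scored by a teacher.

    `models` is a list of 4 callables prompt -> response; `score_with_teacher`
    is a hypothetical (prompt, response) -> float stand-in for the GPT-4 call.
    """
    dataset = []
    for x in prompts:
        responses = [m(x) for m in models]                      # (y^1, ..., y^4)
        scores = [score_with_teacher(x, y) for y in responses]  # s^1, ..., s^4

        best = max(range(len(responses)), key=lambda i: scores[i])
        y_w = responses[best]                                    # highest-scoring response
        y_l = random.choice([y for i, y in enumerate(responses) if i != best])

        dataset.append({"prompt": x, "chosen": y_w, "rejected": y_l})
    return dataset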

🤓 Distilled direct preference optimization (dDPO)

The authors use the dataset D generated in the AIF step for this final stage.
In contrast to RLHF, DPO optimizes the model directly on the preferences, without training a separate reward model. According to the authors, DPO is also lighter-weight and more stable.
The Zephyr authors call this dDPO because the dataset is distilled from the earlier steps, with an AI providing the preference labels.
DPO optimizes for this equation:
\pi_\theta = \underset{\pi}{\max} \, \underset{(x, y_w, y_l)\sim D}{\mathbb{E}} \, \log \sigma \left( \beta \log \frac{\pi(y_w|x)}{\pi_{dSFT}(y_w|x)} - \beta \log \frac{\pi(y_l|x)}{\pi_{dSFT}(y_l|x)} \right)

TL;DR: we tune our LLM against the reference model \pi_{dSFT}, which was instruction-tuned via dSFT, and we alignment-tune it to prefer the stronger responses y_w, as ranked by our teacher model GPT-4. So the entire 3-step process (dSFT, AIF, dDPO) performs both instruction tuning and alignment tuning.
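The DPO objective translates almost line-for-line into code. Below is a minimal sketch of the per-example loss, assuming you've already computed the summed log-probabilities of y_w and y_l under both the policy and the frozen dSFT reference (the same labels-masking trick as in the dSFT snippet above).

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp: torch.Tensor, policy_rejected_logp: torch.Tensor,
             ref_chosen_logp: torch.Tensor, ref_rejected_logp: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Per-example DPO loss; inputs are log pi(y|x) sums under the policy and the frozen dSFT reference."""
    chosen_ratio = policy_chosen_logp - ref_chosen_logp        # log[pi(y_w|x) / pi_dSFT(y_w|x)]
    rejected_ratio = policy_rejected_logp - ref_rejected_logp  # log[pi(y_l|x) / pi_dSFT(y_l|x)]

    # Maximizing log sigma(...) is the same as minimizing -logsigmoid(...).
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio))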

📂 Dataset

The authors of Zephyr focused on two dialogue datasets: UltraChat and UltraFeedback.
UltraChat is a 1.47M-example multi-turn conversation dataset generated by GPT-3.5-Turbo, which the authors filtered down to about 200k examples. UltraFeedback is a 64k-prompt dataset with 4 LLM responses per prompt, rated by GPT-4 for instruction-following, honesty, and helpfulness. Binary preferences were constructed by pairing the response with the highest mean score against one of the weaker responses.
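If you want to poke at these datasets yourself, the Hugging Face H4 team publishes processed versions on the Hub. The dataset IDs and split names below are assumptions based on the H4 org's releases at the time of writing, so double-check them before relying on them.

from datasets import load_dataset

# Filtered UltraChat conversations used for dSFT (dataset ID and split assumed).
ultrachat = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")

# Binarized UltraFeedback preferences (prompt / chosen / rejected) used for dDPO.
ultrafeedback = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split="train_prefs")

print(ultrachat)
print(ultrafeedback)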

📊 Evaluation

They evaluated on 2 benchmarks: MT-Bench and AlpacaEval. As a 3rd check, they tested Zephyr on the Open LLM Leaderboard to look for regressions in academic task performance and truthfulness.

🤔 Training

The \pi_{dSFT} model was trained for one to three epochs with a cosine LR scheduler, a max LR of 2e-5, and 10% warmup steps, using a sequence length of 2048 tokens and a batch size of 512. The DPO model was trained for one to three epochs with a linear LR scheduler, a max LR of 5e-7, and 10% warmup steps, using a batch size of 32 and a \beta of 0.1. The final Zephyr-7B model was initialized from the SFT weights (trained for 1 epoch) and optimized for 3 DPO epochs.
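For a rough sense of how the dSFT settings map onto transformers' TrainingArguments, here's a sketch. It's illustrative only, not the authors' training script: the per-device batch size and accumulation steps are assumptions that depend on your GPU count, and the 2048-token sequence length is handled at tokenization time rather than here.

from transformers import TrainingArguments

# Rough translation of the reported dSFT hyperparameters (illustrative, not the authors' recipe).
sft_args = TrainingArguments(
    output_dir="zephyr-sft",
    num_train_epochs=1,              # the 1-epoch SFT model worked best downstream
    learning_rate=2e-5,              # max LR
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,                # 10% warmup steps
    per_device_train_batch_size=8,   # assumed; global batch size 512 depends on hardware
    gradient_accumulation_steps=8,   # e.g. 8 GPUs x 8 x 8 = 512 (illustrative)
    bf16=True,
)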

📈 Results




They conclude:
  • dSFT + dDPO yields the best performance
  • dDPO improves both chat abilities and academic task performance
  • The model overfits after one epoch of DPO training, but this did not harm downstream performance; training the SFT model for more than one epoch, however, led to regressions

🕵 What is AgentInstruct?

Let's switch gears for a moment and talk about another paper, AgentTuning: Enabling Generalized Agent Abilities for LLMs. We're going to use its data with our Zephyr model. The paper introduces a few things:
  • AgentTuning: A general method to enhance agent abilities in LLMs while maintaining generality
  • AgentInstruct: Instruction-tuning dataset with high-quality interactions
  • AgentLM-7B/13B/70B: agent-tuned LLMs

🔧 Method


The first step in the AgentTuning process is creating the dataset, AgentInstruct. The second step is their instruction-tuning method. This diagram will make sense in a couple of minutes, I promise!

Creating AgentInstruct

Dataset creation consists of 3 stages covering 6 tasks (AlfWorld, WebShop, Mind2Web, Knowledge Graph, Operating System, Database). Instruction construction, as the name suggests, creates the instructions. Trajectory interaction simulates conversations/trajectories for those instructions. Trajectory filtering keeps only the best trajectories for the dataset (there's a quick peek at the resulting data right after the list below).
  • Instruction Construction: Use existing training splits where available (AlfWorld, WebShop, Mind2Web, Knowledge Graph); otherwise, derive the task from a related dataset (task derivation, for Database) or use self-instruct to generate the task, solution, and evaluation script (for Operating System)
  • Trajectory Interaction: Simulate trajectories/conversations with GPT-4 and ChatGPT, with interactions done using Chain-of-Thought and the ReAct reasoning framework
  • Trajectory Filtering: Keep only the successful trajectories/conversations; Mind2Web was difficult, so a lower success threshold was used there
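We'll load this dataset for fine-tuning later, but here's a quick peek at its structure now. The split and field names come from the THUDM/AgentInstruct release on the Hub (the same ones the fine-tuning code below relies on), so treat the exact split names as subject to change.

from datasets import load_dataset

# The released dataset has one split per task.
agent_instruct = load_dataset("THUDM/AgentInstruct")
print(agent_instruct)  # expect splits like alfworld, webshop, mind2web, kg, os, db

# Each example is a list of turns with "from" (role) and "value" (text) keys.
sample = agent_instruct["mind2web"][0]
for turn in sample["conversations"][:2]:
    print(turn["from"], ":", turn["value"][:80])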


AgentTuning

The authors aimed to preserve the model's general abilities while enhancing its agent-related capabilities. To do so, they mixed their AgentInstruct dataset with English-only samples from the ShareGPT dataset: 57,096 GPT-3.5 conversations and 3,670 GPT-4 conversations, sampled at a ratio of 1:4 since GPT-4 responses were of much higher quality.
The 2 datasets are D_{agent} and D_{general}. \eta is a hyperparameter controlling how large a proportion of one dataset you want relative to the other.
J(\theta) = \eta \cdot \mathbb{E}_{(x, y) \sim D_{agent}} [\log \pi_\theta (y|x)] + (1 - \eta) \cdot \mathbb{E}_{(x, y) \sim D_{general}} [\log \pi_\theta (y|x)]

This gnarly equation simply formalizes the regular autoregressive objective, except that the data is weighted differently depending on whether it comes from D_{agent} or D_{general}. In their experiments scanning \eta over range(0, 1, 0.1), they found \eta = 0.2 yielded the best held-out task results.
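In code, the mixture amounts to weighting the autoregressive losses from the two sources. Here's a minimal sketch, assuming agent_batch and general_batch are already-tokenized batches (with labels) from the two datasets; minimizing this weighted loss corresponds to maximizing J(\theta).

def agenttuning_loss(model, agent_batch, general_batch, eta: float = 0.2):
    """Weighted mix of the autoregressive loss on the two data sources.

    eta = 0.2 was the best value in the paper's sweep over range(0, 1, 0.1).
    """
    agent_loss = model(**agent_batch).loss      # -E_{D_agent}[log pi_theta(y|x)]
    general_loss = model(**general_batch).loss  # -E_{D_general}[log pi_theta(y|x)]
    return eta * agent_loss + (1 - eta) * general_loss

In practice you'd more likely implement the mixture by sampling batches from the two datasets in an eta : (1 - eta) ratio, but the weighting idea is the same.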


🧪 Experiments & Results

If you'd like, you can take a look at the diagrams below or scan the paper for a more thorough analysis. But I'll summarize their results here:
  • AgentLM shows significant improvements over Llama 2 on both held-in and held-out tasks
  • On general tasks (knowledge, mathematics, coding, human preferences), AgentLM is on par with Llama 2
  • Ablation study: both agent and general instruction data are important




Fine-tuning with W&B 🪄🐝

The fine-tuning code will be minimally explained. For a deep dive into what this code does, check out my other article on Fine-tuning Mistral 7B with W&B!
Ok. Let's put it all together.
We'll fine-tune our sharded Zephyr-7B-α on the AgentInstruct dataset — specifically the "Mind2Web" task. AgentInstruct is a curated dataset of 1,866 AI Agent interactions across 6 real-world tasks.


🔩 Setup

Before we start, let's install all the necessary requirements and import our libraries.
We will use transformers for instantiating the model and tokenizer, datasets for loading in our dataset, bitsandbytes and accelerate for help with quantizing our model and training, and peft for LoRA.
!pip install git+https://github.com/huggingface/transformers -qqq
!pip install datasets -qqq
!pip install bitsandbytes -qqq
!pip install huggingface_hub -qqq
!pip install peft -qqq
!pip install accelerate -qqq
!pip install trl -qqq
!pip install wandb -qqq

import os

from copy import deepcopy
from random import randrange
from functools import partial

import torch
import accelerate
import bitsandbytes as bnb

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model,
    PeftModel,
)
We'll be using W&B, so don't forget your API key!
import wandb
wandb.login()

🧠 Preparing our model

Next, we'll instantiate our tokenizer and prepare our model. Here's a recap of our model configuration.
  • BitsAndBytesConfig
    • load_in_4bit = True
    • bnb_4bit_use_double_quant = True
    • bnb_4bit_quant_type = "nf4"
    • bnb_4bit_compute_dtype = torch.bfloat16
  • disabled model cache
  • prepare_model_for_kbit_training
    • froze the model weights
    • cast layer norms and lm head to fp32
    • enabled gradients for input embeddings
    • enabled gradient checkpointing
  • LoraConfig
    • lora_alpha = 16
    • lora_dropout = 0.1
    • lora_r = 8
    • applied to the q, k, o, v, gate, up, and down proj layers
model_name = "anakin87/zephyr-7b-alpha-sharded"
tokenizer = AutoTokenizer.from_pretrained(model_name)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

model.config.use_cache = False
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    # lm_head is often excluded.
    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
modules = find_all_linear_names(model)

lora_alpha = 16
lora_dropout = 0.1
lora_r = 8

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=modules,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)
Let's check out how many parameters we are training.
# For `PeftModel`s we can use `get_nb_trainable_parameters` to get the param counts.
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

# Checking memory...
!nvidia-smi



📂 Loading Our Dataset

Let's do a bit of preprocessing. A recap of the preprocessing:
1. Format prompts
2. Tokenize, pad, truncate prompts
3. Shuffle
dataset = load_dataset("THUDM/AgentInstruct")

def format_prompt(sample):
    """Given a sample dictionary with key "conversations", format the conversation into a prompt.

    Args:
        sample: A sample dictionary from a Hugging Face dataset.

    Returns:
        sample: sample dictionary with "text" key for the formatted prompt.
    """
    INTRO = "Below is a conversation between a user and you."
    END = "Instruction: Write a response appropriate to the conversation."

    conversations = ""
    for response in sample["conversations"]:
        from_, value = response["from"], response["value"]
        conversations += f"<{from_}>: " + value + "\n"

    sample["text"] = "\n\n".join([INTRO, conversations, END])

    return sample

def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, dataset, seed: int = 42):
    # Format each prompt.
    print("Preprocessing dataset...")
    dataset = dataset.map(format_prompt)

    def preprocess_batch(batch, tokenizer, max_length):
        return tokenizer(
            batch["text"],
            max_length=max_length,
            truncation=True,
        )

    # Apply preprocessing to each batch of the dataset and remove the "conversations" and "text" fields.
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["conversations", "text"],
    )

    # Shuffle dataset.
    dataset = dataset.shuffle(seed=seed)

    return dataset

max_length = 3000
dataset = preprocess_dataset(tokenizer, max_length, dataset)
Finally, let's log our dataset to W&B Artifacts!
# Logging dataset to W&B Artifacts...
run = wandb.init(
    project="finetuning_zephyr7b",
    name="log_dataset",
)

dataset.save_to_disk("AgentInstruct_prep.hf")
artifact = wandb.Artifact(name="AgentInstruct_prep", type="dataset")
artifact.add_dir("./AgentInstruct_prep.hf", name="train")
run.log_artifact(artifact)

run.finish()
You can find the logged dataset as the artifact "AgentInstruct_prep" in our project "finetuning_zephyr7b". Make sure to click on "Artifacts" in the left-hand navigation bar; this will bring you to the saved artifact!
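As an aside: if you want to reuse this dataset in a later run (or on another machine), you can pull it back down from Artifacts. Here's a small sketch reusing the project and artifact names from above; the run name "use_dataset" is just illustrative.

import os
import wandb
from datasets import load_from_disk

run = wandb.init(project="finetuning_zephyr7b", name="use_dataset")

# Pull the latest version of the artifact we logged above and load it back.
artifact = run.use_artifact("AgentInstruct_prep:latest", type="dataset")
artifact_dir = artifact.download()
dataset = load_from_disk(os.path.join(artifact_dir, "train"))  # "train" was the add_dir name

run.finish()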


👾 Training

Let's define the W&B run for our project "finetuning_zephyr7b". Hugging Face conveniently has an integration with W&B. We'll set the environment variable "WANDB_LOG_MODEL" to "checkpoint" so that all model checkpoints are logged.
run = wandb.init(
    project="finetuning_zephyr7b",  # Project name.
    name="run0",                    # Name of the run within this project.
)

os.environ["WANDB_LOG_MODEL"] = "checkpoint" # Log model checkpoints.
Next, we'll define the training arguments and the trainer.
training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_grad_norm=1.0,
    max_steps=40,
    lr_scheduler_type="linear",
    warmup_steps=5,
    fp16=True,
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="steps",
    save_steps=10,
    optim="paged_adamw_8bit",
    report_to="wandb",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    train_dataset=dataset["mind2web"],
)
Finally, we train!
results = trainer.train() # Now we just run train()!
run.finish()
Here are our training results.



🚀 Running Inference with W&B 🪄🐝

Let's run inference and log our results to a W&B Table. Since we have a simple LLM pipeline, tables are great for logging inference results like input/output, token usage, and response time.
query = "what's a neural network?"

# Logging inference results to W&B Table...
run = wandb.init(
    project="finetuning_zephyr7b",
    name="log_inference",
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_input = tokenizer(text=query, return_tensors="pt").to(device)

start_time_ms = datetime.datetime.now().timestamp() * 1000

_ = model.eval()
with torch.no_grad():
    out = model.generate(**model_input, max_new_tokens=300)
response_text = tokenizer.decode(out[0], skip_special_tokens=True)

end_time_ms = round(datetime.datetime.now().timestamp() * 1000) # logged in milliseconds
token_usage = len(out[0])

data = [
    [
        query,
        response_text,
        end_time_ms - start_time_ms,
        token_usage,
    ]
]

table = wandb.Table(data=data, columns=["input", "output", "latency", "token_usage"])
run.log({"Inference": table})

run.finish()
Below is our very own W&B Table for inference embedded right into this W&B Report!



👋 Conclusion

In this article, I covered Zephyr-7B, AgentTuning, and AgentInstruct, and demonstrated how you can fine-tune Zephyr-7B on AgentInstruct with Hugging Face and W&B. I also highlighted how to use W&B Tables to track model inputs and outputs. Thanks for reading! 👋

References

W&B Traces
Zephyr-7B
AgentInstruct
DPO




