
Fine-tuning Mistral 7B with W&B

Finetuning Mistral 7B on the Puffin dataset with LoRA.
This article is a concise version of this notebook, so follow along! For the W&B project, click here.

Introduction

Before we dive in, let's answer the question:

What is Mistral 7B and Mistral.ai?

Mistral.ai is a Paris-based startup working on open-source LLM technology. Mistral 7B is their first model in the, you guessed it, 7-billion-parameter range. What makes it so special? At only 7 billion parameters, it produces results on par with larger models and even outperforms the Llama 2 13B model!


Below are the benchmarks they evaluated their model on.


One interesting comparison they examined was how this model stacks up against differently sized LLaMA 2 models: it performs well against models 2 to 5 times its size!

Some of the notable aspects of Mistral 7B (taken from their announcement) include:
  • Outperforms Llama 2 13B on all benchmarks
  • Outperforms Llama 1 34B on many benchmarks
  • Approaches CodeLlama 7B performance on code while remaining strong at English tasks
  • Uses Grouped-query attention (GQA) for faster inference
  • Uses Sliding Window Attention (SWA) to handle longer sequences at smaller cost
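To build some intuition for that last bullet, here's a toy sketch of what a sliding-window causal attention mask looks like. The window size of 3 is purely illustrative (Mistral 7B uses a 4096-token window), and this is a conceptual demo rather than how the model actually implements it.

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """True where a query position may attend: causal AND within `window` tokens back."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions (rows)
    j = torch.arange(seq_len).unsqueeze(0)  # key positions (columns)
    return (j <= i) & (j > i - window)

print(sliding_window_causal_mask(6, 3).int())

Each token only attends to itself and the previous window - 1 tokens, so attention cost grows with the window size rather than the full sequence length.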
Great! How do we use it? 🤔
Before we answer that question, here's what we'll be covering in this article:


Ready? Let's dive in!
There are a couple of ways to use Mistral 7B. There's an instruction-tuned version, Mistral 7B Instruct, which can be served behind an OpenAI-compatible API endpoint; check here for more information. The model can also be deployed to AWS, Azure, GCP, and OVH. For this article, though, we will be using the Hugging Face checkpoints, specifically the base Mistral 7B model.
Let's get started! 🤓

🔩Setting Up the Environment for Mistral 7B

What do we need, hardware and software-wise, to start fine-tuning Mistral 7B?
For our project, we won't need a crazy amount of compute. Even just Google Colab will do (which surprised me at first!), though the code works just as well in a regular Jupyter notebook or another environment of your choice. The Mistral docs page also has recommendations on memory capacity. A T4 on Google Colab will be enough.
Below is a list of the dependencies we need.
import os

from copy import deepcopy
from random import randrange
from functools import partial

import torch
import accelerate
import bitsandbytes as bnb

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from transformers.integrations import WandbCallback
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model,
    PeftModel,
)
from trl import SFTTrainer

import wandb  # Needed for the W&B logging sections below.
We use everything here except SFTTrainer and WandbCallback (W&B logging is already integrated into Hugging Face's Trainer, so we never have to attach the callback manually).
I won't go too in-depth on these libraries since I cover them thoroughly in the notebook, but here's a brief overview of the most notable ones. The Hugging Face docs cover every library in their ecosystem.
  • transformers is HuggingFace's most popular library and their hub for models and training, evaluation, preprocessing, and other pipeline components.
  • datasets gives us the power to load in any dataset from the dataset hub.
  • peft is HuggingFace's parameter-efficient fine-tuning library, especially useful for LLMs and limited hardware.
  • trl is HuggingFace's RL training library for language models.
  • accelerate is for distributed configuration and accelerating your PyTorch script.
  • bitsandbytes is a HuggingFace-integrated library of quantization functions that helps reduce our memory footprint.
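If you're starting from a fresh Colab or Jupyter environment, an install cell along these lines should cover the dependencies. It mirrors the imports above; exact versions aren't pinned here, so treat it as a starting point rather than the notebook's canonical setup.

!pip install -q torch transformers datasets accelerate bitsandbytes peft trl wandb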
❗Don't forget you also need a W&B API key and a HuggingFace API key!

🧠Define The Mistral 7B Model

Let's first define the model (it's more involved than the data loading), then load and preprocess the dataset. We will use a sharded copy of Mistral 7B so that loading the checkpoint doesn't exhaust Colab's RAM.
model_name = "someone13574/Mistral-7B-v0.1-sharded"

🎫Tokenizer

The tokenizer is loaded like any other tokenizer. I set the pad_token to the eos_token because of this StackOverflow post.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
The next step is to define our BitsAndBytesConfig. This will significantly reduce memory consumption when we load in our sharded Mistral 7B model.

🐏Bits and Bytes Config & Loading the Model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
Here's a high-level overview of what this configuration specifies:
  • load the weights in 4-bit
  • double quantize (quantize the weights, then quantize the first quantization's constants)
  • use the NF4 (4-bit NormalFloat) data type
  • the compute dtype is bfloat16 (matrix multiplications run in bfloat16)
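To get a rough sense of why 4-bit loading matters, here's a back-of-the-envelope estimate of the weight memory alone (the parameter count is approximate, and real usage adds overhead for quantization constants, activations, and optimizer state):

n_params = 7.24e9  # Mistral 7B has roughly 7.24 billion parameters.

print(f"fp16/bf16 weights: ~{n_params * 2 / 1e9:.1f} GB")   # 2 bytes per parameter
print(f"4-bit weights:     ~{n_params * 0.5 / 1e9:.1f} GB")  # 0.5 bytes per parameter

That difference (roughly 14.5 GB vs. 3.6 GB of weights) is what makes a 16 GB T4 viable for this fine-tune.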
Finally, we can load our model!
Note that I disable caching because this conflicts with enabling gradient checkpointing.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # Automatically selects the device(s) to put the model on.
)
model.config.use_cache = False

🔃Enable Gradient Checkpointing

Next, we enable gradient checkpointing. This can be done in several ways. The simplest way is to run the below.
model.gradient_checkpointing_enable()
Alternatively, you can enable it later by passing in gradient_checkpointing=True to the TrainingArguments class. But for our Mistral 7B fine-tuning task today, we will use a different method: prepare_model_for_kbit_training.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True) # Explicitly specify!
This function does the following:
  • freezes the model weights
  • casts all non-INT8 parameters (the layer norms and the LM head) to fp32 if the model is not GPTQ-quantized
  • calls enable_input_require_grads, which enables gradients for the input embeddings (useful for fine-tuning adapter weights while keeping the base model weights fixed)
  • calls gradient_checkpointing_enable

♑Using LoRA

Finally, we can add the LoRA configuration.
We first need to get a list of the layers in the model that we want to apply LoRA to. Credit to this article!
def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit  # if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    # lm_head is often excluded.
    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

modules = find_all_linear_names(model)
This function returns a list of layer names for LoRA to be applied to: the q, k, v, and o projection layers in attention, plus the gate, up, and down projection layers in the MLPs.
['v_proj', 'down_proj', 'up_proj', 'o_proj', 'q_proj', 'gate_proj', 'k_proj']
Then, we can instantiate the LoraConfig.
lora_alpha = 16
lora_dropout = 0.1
lora_r = 8

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=modules,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)
A quick aside: if you don't know what LoRA is, I encourage you to read up on it before continuing.
The parameters are described as follows:
  • r: the rank of the update matrices, expressed in `int`. Lower rank results in smaller update matrices with fewer trainable parameters.
  • lora_alpha: The alpha parameter for Lora scaling.
  • lora_dropout: The dropout probability for Lora layers.
  • bias: Bias type for Lora. Can be ‘none’, ‘all’ or ‘lora_only’. If ‘all’ or ‘lora_only’, the corresponding biases will be updated during training. Be aware that this means that, even when disabling the adapters, the model will not produce the same output as the base model would have without adaptation.
  • task_type: one of {SEQ_CLS, TOKEN_CLS, CAUSAL_LM, SEQ_2_SEQ_LM, QUESTION_ANS, FEATURE_EXTRACTION}
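To make r concrete, here's a quick back-of-the-envelope count of what LoRA adds to a single linear layer. The 4096x4096 shape is an assumption that roughly matches Mistral's hidden size for the q/o projections (the k/v projections are smaller because of GQA).

d_in, d_out, r = 4096, 4096, 8

full = d_in * d_out        # parameters in the frozen base projection
lora = r * (d_in + d_out)  # parameters in the A (r x d_in) and B (d_out x r) update matrices

print(f"base: {full:,} | LoRA: {lora:,} | ratio: {lora / full:.2%}")

With r=8 that works out to roughly 0.4% extra (trainable) parameters for this layer, which is why the trainable-parameter percentage printed below is so small.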
Finally, we can create our LoRA-applied model.
model = get_peft_model(model, peft_config)
We can also take a look at the parameters and the memory usage.
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

!nvidia-smi

For a 7 billion parameter model, this is pretty good!

📂The Puffin Dataset

We will be fine-tuning Mistral 7B on the Puffin dataset, a collection of roughly 3,000 multi-turn conversations between a user and GPT-4.

🔬Preview

dataset = load_dataset("LDJnr/Puffin", split="train")
We load in the training split for Puffin. Let's take a quick preview of the dataset.
random_sample = dataset[randrange(len(dataset))]
print(type(random_sample))
print(random_sample.keys())
print(random_sample['id'])
print(random_sample["conversations"])
# The output is cut off for convenience.

>>> <class 'dict'>
>>> dict_keys(['id', 'conversations'])
>>> 425
>>> [{'from': 'human', 'value': 'What is the most effective method for producing biofuels from agricultural waste such as corn stover, sugarcane bagasse, or switchgrass? Evaluate the process from a chemical engineering perspective, taking into account factors such as cost, efficiency, and environmental impact.'}, {'from': 'gpt', 'value': 'The most effective method ...'
The dataset has 3000 dictionaries/instances. Each dictionary is structured in the following way.
{
    'id': 1548,  # Some number here.
    'conversations': [
        {'from': 'human', 'value': '<human text>'},
        {'from': 'gpt', 'value': '<gpt text>'},
        ...
    ]
}
Let's zoom in on the conversation!
print(len(random_sample["conversations"]))
print(random_sample["conversations"][0])
print(random_sample["conversations"][1])

>>> 2
>>> {'from': 'human', 'value': 'What is the most effective method for...'}
>>> {'from': 'gpt', 'value': 'The most effective method for producing...'}
Keep in mind, these are multi-turn conversations. They don't end after GPT-4 responds to the user's first message.
for i in dataset:
    if len(i["conversations"]) > 2:
        for j in i["conversations"]:
            print(j)  # Conversations are multi-turn (>= 2) and always even in count (human, then gpt response).
        break

>>> {'from': 'human', 'value': 'How do I center a text element vertically in a row in jetpack compose? The height of the row is determined by a button next to the text element.'}
>>> {'from': 'gpt', 'value': 'In Jetpack Compose, you can center a text element vertically within a row by using the...'}
>>> {'from': 'human', 'value': 'Add some spacing between the text and the button'}
>>> {'from': 'gpt', 'value': 'To add spacing between the `Text` and `Button` elements within the `Row`, ...'}
>>> {'from': 'human', 'value': 'Instead of using a spacer how do I give some padding to the text element instead?'}
>>> {'from': 'gpt', 'value': 'You can add padding to the `Text` element by using the...'}
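If you're curious how the turn counts are distributed across the whole dataset, a quick tally like this works (I'm not reproducing the output here):

from collections import Counter

turn_counts = Counter(len(sample["conversations"]) for sample in dataset)
print(sorted(turn_counts.items()))  # [(2, ...), (4, ...), ...]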

🪄 🐝 Previewing with Weights & Biases!

Previewing the dataset in the notebook is fine, but what if we wanted to return and inspect the dataset? What if we want a snippet of this dataset to be integrated into some powerful dashboard for monitoring and tracking?
Weights & Biases has you covered! I'll showcase a simple but effective demo on how to log this dataset.
First, we need to initialize a run for a given project. Each project can have multiple runs.
wandb_project_name = "finetuning_mistral7b"  # Or whatever you named your W&B project.

run = wandb.init(
    project=wandb_project_name,  # Project name.
    name="log_dataset",          # Name of the run within this project.
    config={                     # Configuration dictionary.
        "split": "train"
    },
    group="dataset",             # Group runs. This run belongs to "dataset".
    tags=["dataset"],            # Tags. More dynamic, low-level grouping.
    notes="Logging subset of Puffin dataset.",  # Description of the run.
)  # Check out the other parameters in `wandb.init`!
Next, we need to figure out how to log the dataset. W&B Tables is a great tool for that, but our samples have a variable number of turns per instance! What do we do? That's totally fine: we can "flatten" the dataset and log it like so.
id | idx | from  | value
88 | 0   | human | text
88 | 1   | gpt   | text

from and value are self-explanatory, as they come straight from the dictionaries. The id is the "id" assigned to each data point in the dataset. The one thing we lose by flattening is the ordering of turns within a conversation: for the data point with id 88, how do we know which turn came first? Simple! We add an index idx to keep track of the order.
data = []
for i in range(1000):  # Log 1000 instances.
    x = dataset[i]
    id_ = x["id"]
    conversations = x["conversations"]
    for idx, response in enumerate(conversations):
        data.append([id_, idx, response["from"], response["value"]])

table = wandb.Table(data=data, columns=["id", "idx", "from", "value"])
run.log({"first1000_Puffin": table})
Here I create a wandb.Table and log it to my run! Note that the data passed in can mix data types; I opted for a list of lists, where each row is [int, int, str, str].
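As an aside, if you already have your rows in a pandas DataFrame, wandb.Table also accepts one directly via the dataframe argument; a minimal sketch, assuming pandas is available:

import pandas as pd

df = pd.DataFrame(data, columns=["id", "idx", "from", "value"])
table = wandb.Table(dataframe=df)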
Then we can simply call run.finish() to finish the run.
run.finish()
Voilà! Now we have a dashboard for our dataset. We can manually inspect the data if we'd like or, in advanced cases, set up some logic to monitor datasets as they are updated. If you want to see the dashboard itself, check out the project.



🚚Preprocessing

For preprocessing, HuggingFace has already handled most of that for us. Our preprocessing pipeline will be very simple.
1. Format prompts
2. Tokenize, pad, truncate prompts
3. Shuffle
Though optional, it does help to format our prompts in a certain way. I've chosen this format.
Below is a conversation between a user and you.

<human>: <value>
<gpt>: <value>
...

Instruction: Write a response appropriate to the conversation.
We need to write a function to format prompts on a per-sample basis (so that we can map it to our entire dataset).
def format_prompt(sample):
    """Given a sample dictionary with key "conversations", format the conversation into a prompt.

    Args:
        sample: A sample dictionary from a Hugging Face dataset.

    Returns:
        sample: The sample dictionary with an added "text" key for the formatted prompt.
    """
    INTRO = "Below is a conversation between a user and you."
    END = "Instruction: Write a response appropriate to the conversation."

    conversations = ""
    for response in sample["conversations"]:
        from_, value = response["from"], response["value"]
        conversations += f"<{from_}>: " + value + "\n"

    sample["text"] = "\n\n".join([INTRO, conversations, END])

    return sample
An example output is shown below!

Next, we need a max length for tokenizing the dataset. For that, we have a simple function that reads the max length the model supports; it defaults to 1024 if no max length is found in the model config. If you have the compute, you can afford to pick a max sequence length that fits your task, but if you're running on Colab's T4, I found 3000 to work fine!
def get_max_length(model):
    conf = model.config
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(conf, length_setting, None)
        if max_length:
            print(f"Found max length: {max_length}")
            break
    if not max_length:
        max_length = 1024
        print(f"Using default max length: {max_length}")
    return max_length

# Change the max length depending on hardware constraints.
# max_length = get_max_length(model)
max_length = 3000  # 3000 works for a T4 on Colab.
>>> 3000
Below is an example of the tokenized/encoded output.

Finally, we want to put all of this together in one function, preprocess_dataset! Again, credit to this article for the code.
# https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, dataset: str, seed: int = 42):
    # Format each prompt.
    print("Preprocessing dataset...")
    dataset = dataset.map(format_prompt)

    # https://blog.ovhcloud.com/fine-tuning-llama-2-models-using-a-single-gpu-qlora-and-ai-notebooks/
    def preprocess_batch(batch, tokenizer, max_length):
        return tokenizer(
            batch["text"],
            max_length=max_length,
            truncation=True,
        )

    # Apply preprocessing to each batch of the dataset and remove the "conversations" and "text" fields.
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["conversations", "text"],
    )

    # Filter out samples that have input_ids exceeding max_length.
    # Not needed, as the tokenizer truncates all prompts over the max length.
    # dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)

    # Shuffle dataset.
    dataset = dataset.shuffle(seed=seed)

    return dataset
I'll keep two copies of the dataset: a formatted one and a fully preprocessed one. The formatted copy will be for visualization on W&B.
formatted_dataset = deepcopy(dataset).map(format_prompt)
dataset = preprocess_dataset(tokenizer, max_length, dataset)

🪄 🐝 Logging + Saving our Preprocessed Dataset to Weights & Biases!

Great! But what if we also want to version our dataset? It's convenient to keep everything experiment-tracking-related in one place, so why not? W&B has Artifacts, which are essentially abstract storage units: you can store datasets, models, files, and pretty much anything else you can think of.
Let's log our preprocessed dataset as a table onto W&B.
run = wandb.init(
    project=wandb_project_name,  # Project name.
    name="log_prep_dataset",     # Name of the run within this project.
    config={                     # Configuration dictionary.
        "split": "train"
    },
    group="dataset",             # Group runs. This run belongs to "dataset".
    tags=["dataset"],            # Tags. More dynamic, low-level grouping.
    notes="Logging preprocessed subset of Puffin dataset.",  # Description of the run.
)  # Check out the other parameters in `wandb.init`!

data = []
for i in range(1000):  # Log 1000 instances.
    x = formatted_dataset[i]
    id_ = x["id"]
    conversation = x["text"]
    data.append([id_, conversation])

table = wandb.Table(data=data, columns=["id", "value"])
run.log({"first1000_prep_Puffin": table})
Next, we save the HuggingFace dataset to disk (it's stored as a folder). Then we can instantiate our Artifact object, add the directory contents to the artifact, and finally log it to W&B.
dataset.save_to_disk("Puffin_prep.hf")

artifact = wandb.Artifact(name="Puffin_prep", type="dataset")
artifact.add_dir("./Puffin_prep.hf", name="train")
run.log_artifact(artifact)
run.finish()
Check that out! We've now logged both the raw and the formatted dataset as Tables, and saved the preprocessed dataset as a dataset Artifact.
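Later, when you want that preprocessed dataset back (say, in a fresh Colab runtime), you can download the artifact and reload it with load_from_disk. A minimal sketch, assuming the artifact name and the "train" subdirectory used above:

import wandb
from datasets import load_from_disk

run = wandb.init(project=wandb_project_name)
artifact = run.use_artifact("Puffin_prep:latest", type="dataset")
artifact_dir = artifact.download()
dataset = load_from_disk(f"{artifact_dir}/train")
run.finish()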




👾Training Mistral 7B

Time to train. 😎
To train our Mistral 7B model, we can use HuggingFace's Trainer and TrainingArguments.
training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=1,  # Best practice: https://huggingface.co/docs/transformers/main/main_classes/quantization#tips-and-best-practices
    gradient_accumulation_steps=4,  # Powers of 2.
    learning_rate=2e-4,
    max_grad_norm=1.0,
    max_steps=40,
    lr_scheduler_type="linear",
    warmup_steps=5,
    fp16=True,
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="steps",
    save_steps=10,
    optim="paged_adamw_8bit",
    report_to="wandb",
)
Most of these arguments are self-explanatory. To log to W&B, all we have to specify is report_to="wandb". That's it! For more information, click here.
Because the Trainer reports to W&B, a run will be started either way. So, if you want to customize the run for your fine-tuning job, initialize it yourself before instantiating the Trainer.
run = wandb.init(
    project=wandb_project_name,
    name="train_run0",  # Sometimes I use the run name as a short descriptor for the run.
    config={
        "split": "train",
        # Optionally, you can add all hyperparameters and configs here for better reproducibility!
    },
    group="train",
    tags=["train", "AdamW"],  # Add tags for what characterizes this run.
    notes="Initial finetuning.",
)
# You can call wandb.init before instantiating the `Trainer` to customize your run!
Let's instantiate the Trainer now!
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    train_dataset=dataset,
)
Pretty simple, right? The DataCollatorForLanguageModeling is just:
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they are not all of the same length.

💭Recap Before Training

Dataset:

  • Puffin dataset
  • formatted prompts to our prompt format
  • tokenized and truncated the data (padding happens later in the data collator)
  • shuffled the data

Model:

  • tokenizer pad_token is the eos_token
  • BitsAndBytesConfig
    • load_in_4bit = True
    • bnb_4bit_use_double_quant = True
    • bnb_4bit_quant_type = "nf4"
    • bnb_4bit_compute_dtype = torch.bfloat16
  • disabled model cache
  • prepare_model_for_kbit_training
    • froze the model weights
    • layer norm and lm head -> fp32
    • enabled gradients for input embeddings
    • enabled gradient checkpointing
  • LoraConfig
    • lora_alpha = 16
    • lora_dropout = 0.1
    • lora_r = 8
    • applied to the q, k, v, o, gate, up, and down projection layers

Training:

  • batch size = 1
  • gradient accumulation = 4
  • lr = 2e-4
  • max_grad_norm = 1.0
  • lr_scheduler_type = "linear"
  • warmup_steps = 5
  • fp16 = True
  • optim = "paged_adamw_8bit"
  • all other parameters left to default
  • DataCollatorForLanguageModeling

Now that you know exactly what we've done to get to this step in the fine-tuning process, let's run the trainer.
results = trainer.train() # Now we just run train()!
run.finish()
You can see that various metrics and statistics are automatically logged behind the scenes to W&B.


HuggingFace's Trainer can also save model checkpoints to W&B Artifacts for you; this is controlled by the WANDB_LOG_MODEL environment variable.
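For example, setting the variable before training starts tells recent versions of transformers to upload checkpoints as Artifacts (a minimal sketch):

import os

# "checkpoint" uploads every saved checkpoint; "end" uploads only the final model.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"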


🚀Inference Using Mistral 7B

Performing inference is a straightforward process. We load in the base model, combine it with our adapter weights (saved in the W&B artifacts for our project), and then run inference as normal.
model_name = "someone13574/Mistral-7B-v0.1-sharded"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)

# You can just use model.
inf_model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
)
For the next step (loading the adapter artifact from W&B), copy the reference code from the saved model artifact's page in your project. Because artifacts are versioned, you can load whichever version you need.

import wandb
run = wandb.init(project=wandb_project_name) # MAKE SURE TO PASS IN YOUR PROJECT NAME!
artifact = run.use_artifact('vincenttu/finetuning_mistral7b/model-t6rw0dav:v0', type='model')
artifact_dir = artifact.download()
run.finish()
Finally, let's load the adapter with PeftModel.from_pretrained. Make sure your path is correct here!
model = PeftModel.from_pretrained(inf_model, "/content/artifacts/model-t6rw0dav:v0")
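Rather than hard-coding that path, you can also pass the artifact_dir variable returned by artifact.download() above, which avoids typos in the version string:

model = PeftModel.from_pretrained(inf_model, artifact_dir)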
Let's run a sample prompt to make sure it works.
prompt = "What is a neural network??"

device = "cuda" if torch.cuda.is_available() else "cpu"
model_input = tokenizer(prompt, return_tensors="pt").to(device)

_ = model.eval()
with torch.no_grad():
out = model.generate(**model_input, max_new_tokens=100)

print(tokenizer.decode(out[0], skip_special_tokens=True))
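The call above uses greedy decoding by default. If you want more varied outputs, you can enable sampling; the specific values below are just a reasonable starting point, not settings tuned for this model:

with torch.no_grad():
    out = model.generate(
        **model_input,
        max_new_tokens=100,
        do_sample=True,   # sample instead of greedy decoding
        temperature=0.7,  # soften the next-token distribution
        top_p=0.9,        # nucleus sampling
    )

print(tokenizer.decode(out[0], skip_special_tokens=True))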


👋Conclusion

🥳🥳Congrats! You just made it to the end. This article is intended as an introduction to fine-tuning Mistral 7B with HuggingFace. If you want a very in-depth guide, do check out my companion notebook. I've learned much from exploring LLM fine-tuning for this article and I hope you did too! Thanks for reading! 👋😎
