
How to Fine-Tune LLaVA on a Custom Dataset

A tutorial for fine-tuning LLaVA on your own data!

Introduction

In mid-2023, LLaVA emerged as a groundbreaking multimodal language model, showcasing an advanced approach integrating language and visual data. Unlike traditional models that primarily focus on either text or image processing, LLaVA stands out for its ability to seamlessly blend both domains. This enables the model to understand and interpret the intricate relationship between visual elements and textual descriptions, leading to more nuanced and contextually rich AI interactions.
In this tutorial, we will be covering how to fine-tune this multi-modal model on your own custom dataset!



How Does LLaVA Work?

At a high level, LLaVA pairs a pre-trained language model such as Vicuna or LLaMA with CLIP's visual encoder, projecting visual features into the language model's embedding space so that images and text can be processed together (the architecture is described in more detail below).
The authors of LLaVA also introduced a visual instruction tuning process, which has proven to be a pioneering approach for multimodal AI. They utilize GPT-4, a language-only model, to generate instruction-following data that pairs language with images. This innovative method involves converting image-text pairs into formats suitable for instruction-following tasks, effectively creating a bridge between visual data and language processing.
Essentially, they take existing datasets of image-text pairs and prompt GPT-4, which is inherently a text-only model, to elaborate on the caption associated with each image. This procedure transforms the original dataset into a more complex, instruction-rich version: GPT-4 generates an array of questions and detailed descriptions based on the initial image captions, effectively deepening the contextual understanding and expanding the instructional content of the data.
This expansion is not just an increase in the volume of text but an enhancement in the quality and depth of information. The language model delves into the nuances of each image, asking pertinent questions and providing detailed descriptions that go beyond the surface level. This results in a dataset that is richer and more suited for training an AI model capable of nuanced multimodal understanding and response.
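To make this concrete, here is a minimal sketch of how one might prompt GPT-4 to expand a plain caption into question-and-answer pairs. It uses the openai Python client; the caption and prompt are illustrative only and are not the authors' actual pipeline or prompts.
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Hypothetical caption; the prompt below is illustrative, not the authors' exact one.
caption = "A group of people standing around a food truck at night."

response = client.chat.completions.create(
    model="gpt-4",
    messages=[
        {"role": "system",
         "content": "You write question-and-answer pairs about an image, "
                    "using only the caption you are given."},
        {"role": "user",
         "content": f"Caption: {caption}\n"
                    "Write three questions someone might ask about this image, "
                    "each followed by a detailed answer."},
    ],
)
print(response.choices[0].message.content)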
Here's an example of the data used for training the LLaVA model:


The Architecture

Architecturally, LLaVA unites the strengths of pre-trained language models like Vicuna or LLaMA with visual models like CLIP's visual encoder. This integration involves transforming the visual features extracted from images into a format that aligns with the language model’s embeddings. The model employs a trainable projection matrix for this purpose, resulting in a sequence of visual token embeddings that are compatible with the language model.
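To illustrate the idea, here is a minimal sketch of such a projector in PyTorch. The dimensions and the two-layer MLP are assumptions chosen for illustration (CLIP ViT-L/14-336 produces 1024-d patch features, and a 7B LLaMA-style model uses 4096-d token embeddings); this is not the repo's actual implementation.
import torch
import torch.nn as nn

VISION_DIM, LLM_DIM = 1024, 4096

# A two-layer MLP projector with GELU, in the spirit of LLaVA-1.5's
# "mlp2x_gelu" projector (the original LLaVA used a single linear layer).
projector = nn.Sequential(
    nn.Linear(VISION_DIM, LLM_DIM),
    nn.GELU(),
    nn.Linear(LLM_DIM, LLM_DIM),
)

# Fake batch of CLIP patch features: (batch, num_patches, vision_dim)
visual_features = torch.randn(2, 576, VISION_DIM)
visual_tokens = projector(visual_features)  # shape: (2, 576, LLM_DIM)

# These projected "visual tokens" are concatenated with the text token
# embeddings before being fed through the language model.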


LLaVA's training comprises a two-stage process. The initial stage, referred to as pre-training, utilizes image-text pairs to align the visual features with the language model's embeddings. This stage keeps the weights of both the visual encoder and language model frozen, focusing on training the projection matrix. The subsequent stage involves fine-tuning the model end-to-end. Here, the visual encoder's weights are frozen, while updates are made to the projection layer and language model.
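The sketch below shows which parts are trainable in each stage, using toy placeholder modules rather than the real components; it is only meant to make the freezing scheme concrete.
import torch.nn as nn

# Toy stand-ins for the real components, just to show which parts are trained.
vision_encoder = nn.Linear(1024, 1024)   # placeholder for CLIP's visual encoder
projector = nn.Linear(1024, 4096)        # placeholder for the projection layer
language_model = nn.Linear(4096, 4096)   # placeholder for the language model

def set_trainable(module, flag):
    for p in module.parameters():
        p.requires_grad = flag

# Stage 1 (pre-training / feature alignment): only the projector is trained.
set_trainable(vision_encoder, False)
set_trainable(language_model, False)
set_trainable(projector, True)

# Stage 2 (end-to-end fine-tuning): the vision encoder stays frozen, while
# the projector and the language model are both updated.
set_trainable(language_model, True)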

The Data for Our Tutorial

For this experiment, we'll focus on fine-tuning LLaVA on a custom dataset using the official LLaVA repo with the Llama-2 7B backbone language model. We will use the OK-VQA dataset, which contains image-text pairs whose questions require reasoning about the image. For example, instead of simply asking the model to describe the image, it poses specific questions about the image's contents.
For fine-tuning LLaVA on the OK-VQA dataset, we must first format the data to align with the specific requirements of the LLaVA repository. The OK-VQA dataset presents a unique challenge with its focus on complex reasoning tasks, involving image-text pairs with questions that go beyond simple image descriptions. These questions require deeper cognitive processing, making it a suitable choice for testing LLaVA's advanced capabilities.
To convert the dataset into a format suitable for the official LLaVA repo, we will write a Python script that produces the following structure:
[
    {
        "id": "unique_id",
        "image": "image_file.jpg",
        "conversations": [
            {
                "from": "human",
                "value": "What is shown in the image?"
            },
            {
                "from": "gpt",
                "value": "formatted_answers"
            }
        ]
    }
]

The script below uses the datasets library from Hugging Face to load and filter the OK-VQA dataset. Specifically, we'll focus on the 'other' question type from the dataset.
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import requests
import os
import json
import uuid


def process_and_save(dataset, output_folder, subset_name):
    # Define image subfolder within output folder
    subset_folder = os.path.join(output_folder, subset_name)
    image_subfolder = os.path.join(output_folder, 'images')

    if not os.path.exists(image_subfolder):
        os.makedirs(image_subfolder)

    if not os.path.exists(subset_folder):
        os.makedirs(subset_folder)

    # Initialize list to hold all JSON data
    json_data_list = []

    # Process and save images and labels
    for item in dataset:
        # Load image if it's a URL or a file path
        if isinstance(item['image'], str):
            response = requests.get(item['image'])
            image = Image.open(BytesIO(response.content))
        else:
            image = item['image']  # Assuming it's a PIL.Image object

        # JPEG cannot store alpha/palette images, so normalize to RGB first
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Create a unique ID for each image
        unique_id = str(uuid.uuid4())

        # Define image path
        image_path = os.path.join(image_subfolder, f"{unique_id}.jpg")

        # Save image
        image.save(image_path)

        # Remove duplicates and format answers
        answers = item['answers']
        unique_answers = list(set(answers))
        formatted_answers = ", ".join(unique_answers)

        # Structure for LLaVA JSON
        json_data = {
            "id": unique_id,
            "image": f"{unique_id}.jpg",
            "conversations": [
                {
                    "from": "human",
                    "value": item['question']
                },
                {
                    "from": "gpt",
                    "value": formatted_answers
                }
            ]
        }

        # Append to list
        json_data_list.append(json_data)

    # Save the JSON data list to a file
    json_output_path = os.path.join(subset_folder, 'dataset.json')
    with open(json_output_path, 'w') as json_file:
        json.dump(json_data_list, json_file, indent=4)


def save_dataset(dataset_name, output_folder, class_name, subset_name, val_samples=None):
    # Load the dataset from Hugging Face
    dataset = load_dataset(dataset_name, split=subset_name)

    # Filter for items with the specified class in 'question_type'
    filtered_dataset = [item for item in dataset if item['question_type'] == class_name]

    # Carve a validation split off the training data; otherwise keep the split as-is
    if val_samples is not None and subset_name == 'train':
        splits = [('train', filtered_dataset[val_samples:]),
                  ('validation', filtered_dataset[:val_samples])]
    else:
        splits = [(subset_name, filtered_dataset)]

    # Process and save the datasets
    for split_name, data in splits:
        if data:
            process_and_save(data, output_folder, split_name)


# Usage example
output_folder = 'dataset'
class_name = 'other'
val_samples = 300
save_dataset('Multimodal-Fatima/OK-VQA_train', output_folder, class_name, 'train', val_samples)
save_dataset('Multimodal-Fatima/OK-VQA_test', output_folder, class_name, 'test')
The script processes each image and its associated question from the dataset, saves each image locally, and assigns it a unique identifier. The questions and answers are formatted into a single JSON file per split. In this structure, the 'human' key represents the person asking the question, and the 'gpt' key represents LLaVA's response. This JSON format is crucial, as it matches the expected input format for LLaVA, enabling effective training and fine-tuning of the model.
Note that we will not follow the same instruction-tuning process demonstrated in the paper; instead, we will mainly focus on training the model to produce a single 'complex reasoning' response given an image and a query.
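Before moving on, it can be worth sanity-checking the generated file. This is a quick, optional check that simply loads the dataset.json produced by the script above and prints one sample.
import json

# Quick sanity check on the file produced by the script above.
with open('dataset/train/dataset.json') as f:
    data = json.load(f)

print(f"{len(data)} training samples")
sample = data[0]
assert {'id', 'image', 'conversations'} <= set(sample)
print(json.dumps(sample, indent=2))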

Training

Now that the dataset is formatted and ready, we move on to the training phase of LLaVA. We will build on the original LLaVA repo. Notably, the original repository lacked support for intermediate evaluations between epochs, which is helpful for identifying signs of overfitting. To address this gap and enhance the training process, I added functionality for periodic evaluations. This version of the training script, along with other modifications made for this project, can be found in the project repository.

Downloading the Pre-trained Weights

In order to download the weights, you can use the following commands:
git lfs install
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
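If you prefer to stay in Python, the huggingface_hub library offers an equivalent download; this is optional and assumes the package is installed.
from huggingface_hub import snapshot_download

# Downloads the LLaVA-1.5 7B weights into ./llava-v1.5-7b
snapshot_download(repo_id="liuhaotian/llava-v1.5-7b", local_dir="llava-v1.5-7b")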

Enhancing Training Efficiency with QLoRA

Training large language models typically presents a challenging trade-off between computational efficiency and model performance. Traditionally, you're either faced with dedicating vast computational resources to training large models or accepting the diminished performance of smaller ones. However, there is an approach that reconciles these conflicting demands: QLoRA.

Understanding LoRA and QLoRA

To grasp the essence of QLoRA (Quantized LoRA), it's essential to first understand LoRA itself. LoRA's strategy is to keep the original pre-trained backbone of the model intact while appending additional, more efficiently trainable low-rank layers. This facilitates rapid adaptation to new tasks without retraining the entire network. By concentrating learning on a small set of new parameters, LoRA retains the benefits of a large pre-trained model with significantly reduced computational demands, which is particularly valuable when resources are constrained or swift adaptation to new data is paramount. QLoRA builds on this by quantizing the frozen backbone, introducing a novel data type, the 4-bit NormalFloat, specifically designed for normally distributed weights, which outperforms other 4-bit data types. This reduces memory requirements even further!
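The LLaVA training script handles all of this internally via its command-line flags (--bits, --lora_r, --lora_alpha). Purely for illustration, here is roughly what the equivalent configuration looks like with Hugging Face's peft and bitsandbytes libraries; the dropout value is an assumption, not something specified by LLaVA.
import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig

# 4-bit NF4 quantization settings in the spirit of QLoRA.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # the 4-bit NormalFloat data type
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# LoRA hyperparameters matching the ones passed to the train command below.
lora_config = LoraConfig(
    r=128,
    lora_alpha=256,
    lora_dropout=0.05,   # assumed value for illustration
    bias="none",
    task_type="CAUSAL_LM",
)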

DeepSpeed

DeepSpeed is an open-source deep learning optimization library designed to enhance the speed, scale, and efficiency of training large-scale deep learning models. Developed by Microsoft, it allows for faster and more efficient training, particularly for very large models, by leveraging various optimization techniques.
One of the key components of DeepSpeed is its ZeRO technology. ZeRO optimizes memory usage during training, enabling the training of much larger models than was previously possible on the same hardware. It is divided into optimization stages; ZeRO Stage 2 reduces memory redundancy by partitioning the optimizer states and gradients across the data-parallel processes, so each process stores only a portion of these components, drastically reducing per-process memory requirements. If you hit CUDA out-of-memory errors with this config, consider the Stage 3 config, which additionally partitions the model parameters and supports offloading to the CPU; this slows down training but may resolve the memory errors.
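For reference, here is a representative ZeRO Stage 2 configuration written out from Python. It is roughly the shape of the zero2.json shipped with the LLaVA repo, but the exact values there may differ; the "auto" placeholders rely on the Hugging Face Trainer integration filling them in.
import json

zero2_config = {
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,                     # partition optimizer states and gradients
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": "auto",
    },
}

with open("zero2.json", "w") as f:
    json.dump(zero2_config, f, indent=2)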

The Train Command

I won’t go into the details of the training script; however, I will cover the run command used to launch it, as many of the important details can be covered there. Generally, instead of pasting a long command like this into the terminal, I prefer to create a bash script (with the .sh extension) and place the command in that file. I've found this makes it easier to test out different hyperparameters and avoid syntax errors in the command line.
The training script, train.py, is executed with the following command:
#!/bin/bash


# Set the prompt and model versions directly in the command
deepspeed /root/LLaVA/llava/train/train_mem.py \
--deepspeed /root/LLaVA/scripts/zero2.json \
--lora_enable True \
--lora_r 128 \
--lora_alpha 256 \
--mm_projector_lr 2e-5 \
--bits 4 \
--model_name_or_path /root/LLaVA/llava/llava-v1.5-7b \
--version llava_llama_2 \
--data_path /root/dataset/train/dataset.json \
--validation_data_path /root/dataset/validation/dataset.json \
--image_folder /root/dataset/images/ \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir /root/LLaVA/llava/checkpoints/llama-2-7b-chat-task-qlora \
--num_train_epochs 500 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "epoch" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb


You can copy this run command into a file called run.sh and then launch it with sh -x run.sh to start training.
Key parameters in the command include:
  • mm_projector_lr: The separate learning rate for the multimodal projector as specified by the LLaVA authors
  • bits: Setting this to 4 enables 4-bit quantization, which is what makes this a QLoRA fine-tune
  • lora_alpha: Following the guidelines of the LLaVA authors, we've set lora_alpha to 256. This alpha value is pivotal in preserving numerical stability and the full expressive power of the model. It's worth noting that this is an adjustment from the typical values around 16
  • lora_r: The lora_r parameter represents the rank of the decomposition matrices in LoRA. We've chosen a value of 128, diverging from the common range of 8 to 64 seen in typical LLM fine-tunes. A higher rank, as in our case, can enhance the model's representational capability
  • mm_projector_type: I set this to mlp2x_gelu, which is a multi-layer perceptron with GELU activation
  • deepspeed: Here we specify the deepspeed zero stage 2 config for the training run
  • data_path: This parameter specifies the location of the training dataset that we created earlier
  • validation_data_path: Since I added intermediate evaluations between each epoch, we will need to pass the path to our validation dataset as well (note that the code assumes the images for both train and validation are in the same directory)
  • image_folder: This argument points to the directory containing the images used in both the training and validation datasets.
  • output_dir: This is the directory where the trained model checkpoints will be saved. It’s important to have sufficient storage space in this directory, especially when training large models like LLaVA
Depending on your hardware setup, you can change the batch size to avoid memory errors. I trained on 8 NVIDIA RTX 3090s, which had no issues with a batch size of 32. The training script also supports monitoring with Weights & Biases via the --report_to wandb flag, providing real-time tracking of the model's progress and performance metrics.
After about 10 epochs, the loss for the train and validation set started to level off, so I decided to stop the training run a little early. Here are my training logs from the run:

Run: frosty-river-92

When using the LLaVA repo, the output folder for your LoRA adapter must contain the strings "llava" and "lora" in its name in order to be picked up correctly by the run_llava script, which is used to test the model for inference.

Running Inference

Now that we have trained our model, it's time to test it out! The LLaVA repo provides a script called run_llava.py, which makes it easy to run our fine-tuned model. I found an interesting image in the validation set that we can use to test the model:

The following command passes the model path, image path, and query to the script, which will merge the LoRA weights with the base model and run inference using the supplied arguments:
python run_llava.py --model-path /root/LLaVA/llava/checkpoints/llava-2-7b-chat-task-qlora/best_llava_eval_model_llava_lora \
    --model-base /root/LLaVA/llava/llava-v1.5-7b \
    --image-file /root/dataset/images/0f47c0b5-2c77-45e6-87b0-89af46e99500.jpg \
    --query "why was this photo taken?"
This takes the LoRA adapter you trained, merges it with the base model, and runs inference on the given image and query. Below is the response from the model:


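If you'd rather run inference from Python than from the command line, the LLaVA repo also exposes an eval_model helper. The sketch below assumes the API shown in the repo's README; argument names may differ between versions, so treat it as a guide rather than a drop-in script.
from llava.eval.run_llava import eval_model
from llava.mm_utils import get_model_name_from_path

model_path = "/root/LLaVA/llava/checkpoints/llava-2-7b-chat-task-qlora/best_llava_eval_model_llava_lora"

# eval_model expects an argparse-style namespace; a simple stand-in object works.
args = type("Args", (), {
    "model_path": model_path,
    "model_base": "/root/LLaVA/llava/llava-v1.5-7b",
    "model_name": get_model_name_from_path(model_path),
    "query": "why was this photo taken?",
    "conv_mode": None,
    "image_file": "/root/dataset/images/0f47c0b5-2c77-45e6-87b0-89af46e99500.jpg",
    "sep": ",",
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 512,
})()

eval_model(args)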
Overall

As we conclude this tutorial on fine-tuning the LLaVA model, it's evident that the journey of integrating language and visual data through AI has taken a significant leap forward. LLaVA, with its innovative approach of visual instruction tuning using GPT-4, has not only bridged the gap between textual and visual understanding but also expanded the horizons of multimodal AI capabilities. I hope you have enjoyed this tutorial! If you have any questions, feel free to ask in the comments below!
