Fixing my fine-tuning
How do I inspect what's wrong with my fine-tuning? A hands-on approach!
After successfully fine-tuning Gemma with the new Zephyr recipe from the Hugging Face H4 team, I decided to give trusty old Mistral-7B a new fine-tune. As you can see in the loss and gradient-norm curves below, it didn't work as expected. These spikes in the loss and gradients happen from time to time, and most of the time you would just restart, adjust some hyperparameters, and call it a day. In my case, even reducing the learning rate, increasing the batch size, and enabling gradient clipping was not enough; more importantly, I want to know what's going on!
Run set: 3 runs (training loss and gradient-norm curves)
Fine-tuning recipe and telemetry
In these experiments we are using the fine-tuning recipes from the Hugging Face Alignment Handbook repo. This is the code that was used to produce the Zephyr Mistral and Gemma models. The code is very straightforward to use, leveraging the TRL and transformers libraries. There is one script for the supervised fine-tuning (SFT) phase and one for direct preference optimization (DPO). Both scripts are similar and very readable: they load the data from the Hub, set up the model configuration, and launch training with the corresponding trainers (SFTTrainer and DPOTrainer). As both of these trainers subclass the original transformers.Trainer, we get W&B logging built in.
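Stripped of the handbook's argument parsing, chat templating, and multi-GPU launch plumbing, the SFT side boils down to something like the sketch below. This is not the actual run_sft.py: the dataset is a stand-in that already ships a plain text column, the hyperparameter values are illustrative, and the SFTTrainer arguments assume a TRL version from around when these runs were made.

```python
from datasets import load_dataset
from transformers import AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Stand-in dataset with a ready-made "text" column; the handbook recipes
# use UltraChat-style chat data plus a chat template instead.
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

training_args = TrainingArguments(
    output_dir="mistral-7b-sft",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    num_train_epochs=1,
    bf16=True,
    logging_steps=5,
    report_to=["wandb"],  # built-in W&B logging via the Trainer integration
)

trainer = SFTTrainer(
    model=model_id,             # SFTTrainer loads the model from the hub id
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",  # column that holds the training text
    max_seq_length=2048,
    packing=True,               # pack several examples into each sequence
)
trainer.train()
```

The DPO script has the same shape, with DPOTrainer and a preference dataset in place of the SFT pieces.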
This repo is set up so that all of the configuration is passed as a "recipe" in the form of a YAML file. You can find the different recipes in the recipes folder; there are recipes for both full fine-tunes and QLoRA.
You only need to add report_to: wandb to the appropriate config and you are good to go: your metrics will be streamed to W&B.
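For reference, a trimmed SFT recipe with W&B reporting enabled would look roughly like the excerpt below. The field names mirror transformers' TrainingArguments plus the handbook's own data fields; the values here are illustrative, not my exact settings.

```yaml
# Excerpt in the spirit of the handbook's SFT config_full.yaml (illustrative values)
model_name_or_path: mistralai/Mistral-7B-v0.1
dataset_mixer:
  HuggingFaceH4/ultrachat_200k: 1.0
learning_rate: 2.0e-05
per_device_train_batch_size: 8
gradient_accumulation_steps: 4
logging_steps: 5
report_to:
- wandb   # stream training metrics to W&B
```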
Checking the model's inputs when the loss spikes
What I would like to know is what data is producing this sudden spike in the loss and gradients. I want to see the batch that was fed to the model; maybe there is something fishy with it. Maybe I am doing something wrong on my data-processing side? Who knows!
This is easier said than done, as the model iterates over the data on multiple GPUs, packing the dataset into multiple sequences and averaging (gathering) the loss onto the rank 0 process, which is what gets reported back to W&B.
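One way to get at the offending batch (a sketch of my own, not something the handbook ships) is to subclass the trainer and override compute_loss, so that whenever a micro-batch loss on some rank jumps above a threshold, we decode its input_ids and log them as a W&B table. The class name, threshold, and table name below are made up for illustration, and the override assumes the compute_loss signature of the transformers release current at the time of these runs.

```python
import wandb
from trl import SFTTrainer


class InspectableSFTTrainer(SFTTrainer):
    """Sketch: dump the raw batch whenever the per-rank loss spikes."""

    LOSS_THRESHOLD = 5.0  # arbitrary cutoff, tune for your run

    def compute_loss(self, model, inputs, return_outputs=False):
        out = super().compute_loss(model, inputs, return_outputs=return_outputs)
        loss = out[0] if return_outputs else out

        # Check this rank's micro-batch loss *before* it gets gathered
        # and averaged onto rank 0 for logging.
        if loss.detach().item() > self.LOSS_THRESHOLD and wandb.run is not None:
            texts = self.tokenizer.batch_decode(
                inputs["input_ids"], skip_special_tokens=False
            )
            table = wandb.Table(columns=["rank", "step", "loss", "text"])
            for text in texts:
                table.add_data(
                    self.args.process_index,
                    self.state.global_step,
                    loss.detach().item(),
                    text,
                )
            wandb.log({"suspicious_batches": table})

        return out
```

Even this only logs from ranks that have an active W&B run, and by default the Trainer integration starts one on rank 0 only, so the batches seen by the other GPUs would still need to be gathered or logged separately, which is exactly why this is harder than it sounds.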