RLHF: Hyperparameter Optimization for trlX
In this article, we follow a step-by-step guide for performing hyperparameter optimization using Ray Tune and Weights & Biases, looking at how trlX can help.
In this article, we're going to look at how Transformer Reinforcement Learning X (trlX) from CarperAI, along with other new techniques in reinforcement learning, can help create better, safer language models through Reinforcement Learning from Human Feedback (RLHF). We'll start with a brief introduction to RLHF and why it's important before digging into how it's actually done.
Table of Contents
- Why is Reinforcement Learning From Human Feedback (RLHF) Important?
- What is Transformer Reinforcement Learning X (trlX)?
- Hyperparameter Optimization
- Conclusion
Why is Reinforcement Learning From Human Feedback (RLHF) Important?
Reinforcement Learning from Human Feedback (RLHF) is a critical field of research for creating safe, responsible ML models. We want our big models to align with human preferences rather than imitate human behavior. In other words, we want these models to have an abstract notion of certain preferences.
For example, having an abstract notion of "do not create NSFW content", "do not generate profanity", "do not lie", etc., would be incredibly useful. This work was pioneered by OpenAI and has been used to train language models that are much better at following user intentions than GPT-3, while also making them more truthful and less toxic.
We have recently seen massive community-driven efforts to reimplement these models so everyone can benefit from them. Big models are getting truly democratized: by the people and for the people. One such community-driven research lab is CarperAI. It is the newest lab within the EleutherAI research collective, focusing on improving the performance and safety of large language models (LLMs) with reinforcement learning.
What is Transformer Reinforcement Learning X (trlX)?
Transformer Reinforcement Learning X (trlX) is a repo developed by CarperAI to facilitate the training of language models with Reinforcement Learning from Human Feedback (RLHF). trlX lets you fine-tune HuggingFace-supported language models such as GPT-2, GPT-J, GPT-Neo, and GPT-NeoX.
It supports models of up to 20B parameters and trains them with reinforcement learning using either a provided reward function or a reward-labelled dataset. Two algorithms are implemented to achieve this: Proximal Policy Optimization (PPO) and Implicit Language Q-Learning (ILQL).
Running an Example Script
trlX comes with multiple examples that can help you get started with the library and that you can take inspiration from to build your own pipelines. Each example script has an associated config file in the configs dir. Still, it always helps to center ourselves on a real-world project, so let's look at the ppo_sentiments example.
Specifically, this example trains a model to generate movie reviews with positive sentiment in an online setting: it samples reviews from a model fine-tuned on the IMDB dataset and rates those samples with a learned sentiment reward model.
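To give a flavor of what the example does, here is a minimal sketch of an online PPO setup in trlX. The model names, prompts, and the exact trlx.train keyword arguments below are illustrative assumptions inspired by the repository's examples, not a verbatim copy of ppo_sentiments.py:

# A rough sketch (not the actual example script): reward positive-sentiment reviews.
import trlx
from transformers import pipeline

# Assumption: a sentiment classifier serves as the learned reward model.
sentiment_fn = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")

def reward_fn(samples, **kwargs):
    # Score each generated review; more positive sentiment means higher reward.
    outputs = sentiment_fn(samples)
    return [o["score"] if o["label"] == "POSITIVE" else -o["score"] for o in outputs]

# Assumption: trlx.train accepts a reward function and a list of prompts,
# as shown in the repository's quickstart.
trainer = trlx.train(
    "lvwerra/gpt2-imdb",  # GPT-2 fine-tuned on IMDB reviews
    reward_fn=reward_fn,
    prompts=["The movie was", "I watched this film and"],
)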
Hyperparameter Optimization
Now that we have some context on trlX and a baseline on ppo_sentiments, let's dig into how we can perform hyperparameter optimization for trlX. A few disclaimers:
- Since trlX is a relatively new library, the hyperparameter optimization workflow is a work in progress too. We want to make it seamless to use.
- The workflow might change depending on user feedback, but this report will reflect the latest changes.
trlX supports distributed and scalable hyperparameter optimization using Ray Tune and Weights & Biases. Next, we're going to walk through, step-by-step, how to use all three tools in tandem. (If you are new to Ray Tune, check out the Getting Started page to learn more.)
Search Space
Every hyperparameter optimization tool requires you to set up a search space to optimize over. The search space for trlX is defined using a .yml config file inside the configs/sweeps directory.
- You will find an associated config file inside the configs dir for a given example task (you will have to create one for your custom task). In our example, the associated config file is configs/ppo_config.yml. This is the default config on top of which the search space is applied.
- Create a new .yml file inside configs/sweeps and name it ppo_sweep.yml. Here you will define the search space as shown in the code snippet below:
lr_init:
  strategy: "loguniform"
  values: [0.00001, 0.01]
init_kl_coef:
  strategy: "uniform"
  values: [0, 0.2]
vf_coef:
  strategy: "uniform"
  values: [0.5, 2]
The hyperparameters that you want to define a search space for constitute the top-level keys in the .yml file.
Strategy and Values
Each top-level hyperparameter is followed by strategy and values, which define the sampling type and the arguments for that sampling strategy.
The available strategies are listed below. You can also look at the get_strategy function here to learn more about it:
strategy | Definition
---|---
uniform | Samples uniformly between the given bounds.
quniform | Samples uniformly between the given bounds, quantized.
loguniform | Samples uniformly between the given bounds on a log scale.
qloguniform | Samples uniformly between the given bounds on a log scale, quantized.
randn | Samples from a normal distribution.
qrandn | Samples from a normal distribution, quantized.
randint | Samples integers uniformly between the given bounds.
qrandint | Samples integers uniformly between the given bounds, quantized.
lograndint | Samples integers uniformly between the given bounds on a log scale.
qlograndint | Samples integers uniformly between the given bounds on a log scale, quantized.
choice | Samples from a discrete set of values.
grid_search | Runs a grid search over the given list of values.
In essence, values are the arguments for the chosen strategy. Check out Ray Tune's search space API here to learn about the acceptable values for each strategy.
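To make the mapping concrete, here's an approximate illustration of how a strategy/values pair translates into Ray Tune's search space API. The actual translation happens inside trlX's get_strategy; treat this as a sketch rather than its implementation:

from ray import tune

# Approximate Ray Tune equivalent of the search space YAML above.
param_space = {
    "lr_init": tune.loguniform(1e-5, 1e-2),   # strategy: "loguniform", values: [0.00001, 0.01]
    "init_kl_coef": tune.uniform(0.0, 0.2),   # strategy: "uniform", values: [0, 0.2]
    "vf_coef": tune.uniform(0.5, 2.0),        # strategy: "uniform", values: [0.5, 2]
}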
Tune Config
Now that the search space for your hyperparameters is defined, let's quickly add a tune_config to our configs/sweeps/ppo_sweep.yml. This block defines the metric we are optimizing our hyperparameters for, the search algorithm, and the scheduler.
An example tune config is shown below:
tune_config:
  mode: "max"
  metric: "mean_reward"
  search_alg: "random"
  scheduler: "fifo"
  num_samples: 32
In our example, we are trying to maximize the mean_reward metric. num_samples determines the number of times hyperparameters are sampled from the search space, i.e., the total number of trials.
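For readers more familiar with Ray Tune's Python API, the tune_config block corresponds roughly to a tune.TuneConfig like the one below; this mapping is an illustrative assumption, not trlX's exact code:

from ray import tune
from ray.tune.schedulers import FIFOScheduler

# Approximate Ray Tune equivalent of the tune_config block above (illustrative only).
tune_config = tune.TuneConfig(
    metric="mean_reward",       # the metric reported by the training function
    mode="max",                 # maximize that metric
    search_alg=None,            # None falls back to random sampling ("random")
    scheduler=FIFOScheduler(),  # "fifo": run trials in order, no early stopping
    num_samples=32,             # number of hyperparameter samples (trials)
)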
Search Algorithm and Scheduler
Ray Tune supports multiple search algorithms (you can check those out here). Currently, trlX supports Bayesian Optimization HyperBand (BOHB) and random search.
To use BOHB, set search_alg to "bohb". Importantly, BOHB is intended to be paired with a specific scheduler class called HyperBandForBOHB, which you can select by passing "hyperbandforbohb" to scheduler.
To use BOHB, you will have to install these libraries: pip install hpbandster ConfigSpace
You can also perform "random" search with the "fifo" (first-in, first-out) scheduler. However, it's recommended that you use Bayesian Optimization with HyperBand to reduce the time spent on optimization, because trials (a.k.a. experiments) that will not reach a good reward value are terminated automatically by the scheduler.
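Under the hood, these string options correspond to Ray Tune classes. Purely as an illustration (not trlX's exact wiring, and import paths may differ across Ray versions), "bohb" and "hyperbandforbohb" map to something like:

from ray.tune.search.bohb import TuneBOHB
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB

# Illustrative mapping of the YAML options to Ray Tune objects; metric and mode
# are inherited from the tune config.
search_alg = TuneBOHB()  # Bayesian Optimization HyperBand searcher
scheduler = HyperBandForBOHB(
    time_attr="training_iteration",  # budget dimension used to stop weak trials early
    max_t=100,                       # maximum budget per trial (assumed value)
)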
The metrics generated while training the models are logged as shown in the panels below. These metrics were generated using random search and the fifo scheduler:
Train Function
Now that we have defined the search space, we need a model training function that takes a config dictionary as an argument. The hyperparameters are sampled from the search space and passed to this training function. The training logic is implemented inside the function, which communicates with Ray Tune by logging the metrics using ray.air.session.report.
In trlX, the provided scripts in examples/ act as the training functions. You simply pass the path to the example script, and everything else is taken care of under the hood.
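If you were to write a training function for a custom task, it would look roughly like the sketch below. The config keys and the fake training loop are hypothetical placeholders; the only real requirement is that the metric named in tune_config is reported through ray.air.session.report:

from ray.air import session

def train_fn(config):
    # config holds one sampled set of hyperparameters, e.g. config["lr_init"].
    lr = config["lr_init"]  # hypothetical key, matching the sweep YAML above
    mean_reward = 0.0
    for step in range(10):
        # Stand-in for a real training step that would update the policy
        # and compute the reward on sampled rollouts.
        mean_reward += lr
        # Report the metric so Ray Tune can compare and schedule trials.
        session.report({"mean_reward": mean_reward})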
Running Hyperparameter Optimization
Everything is stitched together by the trlx/sweep.py file. It takes a few command-line arguments, listed below:
argparse | Definition
---|---
script | Path to the example script in examples/. This acts as the training function for Ray Tune.
config | The config file where the parameter space and tune config are defined. In our example, it's configs/sweeps/ppo_sweep.yml.
num-cpus | Number of CPU cores to allocate to each trial.
num-gpus | Number of GPUs to allocate to each trial.
server-address | Use this to launch the hyperparameter optimization on a cluster. Learn more about it here.
You can run the command as shown below:
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py --num-cpus 2 --num-gpus 1
This will allocate 2 CPU cores and 1 GPU per trial. Distributed training within a trial is not currently supported, so using only one GPU per trial is recommended.
If you face issues or have ideas to improve the interface, please raise an issue in the CarperAI/trlx repo.
Using Weights & Biases to Visualize your Hyperparameters
The metrics generated during each trial are saved as .json files in the local directory ray_results/. Once all the trials are done, the log_trials function logs all the metrics to your Weights & Biases project workspace.
You can control the project name via the project_name argument in configs/ppo_sentiments.py.
Since the metrics are logged only after the Ray Tune trials are done, they are not logged in real time, which also means system metrics are not available. This was done to decouple W&B from Ray Tune and to allow trying out multiple search algorithms, as a few of them weren't playing well with W&B. We will work on logging the metrics in real time in a future iteration.
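The actual implementation lives in trlx/ray_tune/wandb.py; purely to illustrate the pattern (this is not that file's code, and the result-file layout is an assumption), replaying finished trials into W&B could look like this:

import json
import os

import wandb

def replay_trials(ray_results_dir, project_name):
    # For every trial directory, replay its saved metrics into a W&B run.
    for trial in os.listdir(ray_results_dir):
        result_file = os.path.join(ray_results_dir, trial, "result.json")
        if not os.path.isfile(result_file):
            continue
        run = wandb.init(project=project_name, name=trial, reinit=True)
        with open(result_file) as f:
            for line in f:  # Ray Tune writes one JSON record per reported step
                run.log(json.loads(line))
        run.finish()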
With the Weights & Biases integration, you have access to the training metrics along with the hyperparameters used for that experiment.
Parallel Coordinate Plot
A parallel coordinates chart summarizes the relationship between a large number of hyperparameters and model metrics at a glance. Plus, they look rather nice. You can learn more here, but here's a look at how our hyperparameters affect our reward:
Parameter Importance
The parameter importance plot lists the hyperparameters that were the best predictors of desirable metric values. In other words, it shows which parameters mattered most (hence the name). You can learn more here.
From the example parameter importance chart shown below:
- lr_init is the most important hyperparameter and negatively correlates with mean_reward. This makes sense because a lower learning rate usually helps achieve better convergence.
- num_rollouts has a positive correlation with mean_reward, indicating that the more episodes the model sees, the better it performs.
You can use these insights to do granular hyperparameter optimization.
Scatter Plot
The scatter plot compares different trials and gives you insight into how the trials progressed. You can learn more about it here.
Automatic W&B Analysis Template
The interface automatically generates a W&B report with the parallel coordinates plot, parameter importance plot, scatter plot, metrics, and the best config using the W&B Reports API. The logic for report generation can be found in trlx/ray_tune/wandb.py. Each generated report has a unique name identified by Ray Tune's ID. Some notes:
- As a user, you will not have to create the charts shown above manually.
- You can share the report with your teammates and colleagues to start a discussion.
- Share it with the community to get feedback. Users without edit access to the report can comment on any section of the report.
An example report is shown below:
Conclusion
This report documents the steps required to start using hyperparameter optimization for trlX. This integration is meant for researchers and users alike, and we want to make it as seamless as possible. Currently, examples/ppo_sentiments.py and examples/ilql_sentiments.py are supported by default, but other tasks can easily be added by following the steps documented above. We hope you find it useful.
If you find any issues or face difficulty using it, please raise an issue in the CarperAI/trlx repo and tag @ayulockin. Feel free to join CarperAI's Discord to discuss RLHF further or to help them build open models.