An Example of Transformer Reinforcement Learning
In this article, we take a look at the logged metrics and gradients from a GPT-2 experiment that is tasked with writing favorable reviews for movies.
Overview
Transformer Reinforcement Learning (TRL) is a library for training transformer language models with Proximal Policy Optimization (PPO), built on top of the Hugging Face transformers library.
In this article you'll see logged metrics and gradients from an example project: fine-tuning GPT-2 to generate positive movie reviews. The language model takes a few words from a movie review as input and is tasked with finishing the review in a positive tone, using a sentiment classifier to calculate the reward.
Here's an overview of the training cycle. In each step the model generates and evaluates a batch of 256 reviews, and the whole process takes about two hours on a P6000 GPU.
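For concreteness, here is a minimal sketch of one such step using the TRL API (PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead). Class names and arguments have changed across TRL releases, and the reward-model identifier and generation settings below are assumptions rather than the exact configuration of this experiment.

```python
# Minimal sketch of one training step, assuming a TRL release that exposes
# PPOConfig / PPOTrainer / AutoModelForCausalLMWithValueHead. The sentiment model
# identifier and generation settings are illustrative, not this experiment's config.
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, pipeline
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

config = PPOConfig(model_name="gpt2", batch_size=256)  # 256 reviews per step
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)  # frozen reference
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")  # assumed reward model

# Queries: the first few tokens of IMDB reviews.
reviews = load_dataset("imdb", split="train[:256]")["text"]
query_tensors = [tokenizer(r, return_tensors="pt").input_ids[0, :8] for r in reviews]

gen_kwargs = {"max_new_tokens": 20, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}

# 1. Generate a continuation for every query.
response_tensors = [ppo_trainer.generate(q, **gen_kwargs).squeeze()[len(q):] for q in query_tensors]
texts = [tokenizer.decode(torch.cat([q, r])) for q, r in zip(query_tensors, response_tensors)]

# 2. Score the full reviews with the sentiment classifier (positive-class score as reward).
pipe_outputs = sentiment_pipe(texts, top_k=None)
rewards = [torch.tensor(next(d["score"] for d in out if d["label"] == "POSITIVE"))  # label name assumed
           for out in pipe_outputs]

# 3. One PPO optimisation step; the returned stats can be logged to W&B.
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
```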

Rewards
In contrast to supervised learning, reinforcement learning uses a scalar reward signal instead of explicit labels. The model figures out how to adjust its outputs to generate more rewards in the future, so it's useful to log rewards over time.
The line plot shows how the reward increases over training, meaning the model is learning to generate positive movie reviews. The heatmap shows how the distribution of rewards develops over time: in the first 10 steps the model generates a broad range of mediocre reviews, but as training continues a solid band of positive rewards appears at the top of the heatmap.
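If you instrument your own training loop, one way to produce both views is to log the batch reward mean together with a wandb.Histogram at every step (assuming wandb.init has been called); the key names below are illustrative, not necessarily the ones used in this run.

```python
# Hedged sketch: log per-step reward statistics so W&B can render a line plot
# (mean reward) and a distribution view over time. Key names are illustrative.
import numpy as np
import wandb

def log_rewards(step, rewards):
    rewards = np.asarray(rewards, dtype=np.float32)
    wandb.log(
        {
            "env/reward_mean": rewards.mean(),
            "env/reward_std": rewards.std(),
            "env/reward_dist": wandb.Histogram(rewards),  # stacked over steps gives the heatmap-style view
        },
        step=step,
    )
```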
Model Outputs
At each timestep, a table with the model's queries, responses, and rewards is logged. The queries are text snippets sampled from the IMDB dataset, and the responses are the language model's generations conditioned on those queries.
The "positivity" of the generated review is determined using a sentiment classifier. This is a BERT model trained to predict sentiment on the IMDB dataset.
This table helps debug the language model's text generation and monitor the training progress qualitatively. Also, reading the continuations is a good distraction while waiting for the model to finish training.
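Below is a minimal sketch of how such a table can be logged with wandb.Table; the column and key names are assumptions, not necessarily what this run used.

```python
import wandb

def log_game_table(step, queries, responses, rewards):
    # One row per generated review: query snippet, continuation, scalar reward.
    table = wandb.Table(columns=["query", "response", "reward"])
    for query, response, reward in zip(queries, responses, rewards):
        table.add_data(query, response, float(reward))
    wandb.log({"game_log": table}, step=step)
```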
KL Control
The output of the sentiment classifier is not the only objective in the PPO optimisation: the KL-divergence from a reference model (the same model before training) is also used. To keep the generated text from deviating too far from the reference model, the KL-divergence is subtracted from the reward; the resulting quantity is referred to here as the return.
The KL-divergence is scaled dynamically to achieve a predefined target divergence. The coefficient is called kl_coef, and its negated product with the KL-divergence is the non_score_reward. The equation is:

return = reward + non_score_reward = reward - kl_coef * KL(model || reference model)
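The dynamic scaling follows the proportional controller from Ziegler et al. (2019). The sketch below shows the idea; the constants are commonly used defaults rather than the exact values of this experiment.

```python
import numpy as np

class AdaptiveKLController:
    """Proportional controller that nudges kl_coef toward a target KL-divergence
    (after Ziegler et al., 2019). The constants are illustrative defaults."""

    def __init__(self, init_kl_coef=0.2, target=6.0, horizon=10000):
        self.kl_coef = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Raise the penalty when the measured KL overshoots the target, relax it otherwise.
        proportional_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        self.kl_coef *= 1 + proportional_error * n_steps / self.horizon

# Per step: non_score_reward = -kl_coef * KL(model || reference model)
#           return           =  reward + non_score_reward
```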
Log Probabilities
As mentioned before, it is important that the trained language model does not deviate too far from the original model. Thus, we monitor the log probabilities of the output tokens under both models and compare them; the difference gives a proxy for the KL-divergence.
One can observe that the model stays reasonably close to its original distribution. There were a few issues related to PPO exploiting the text generation heuristics in the transformers library, and these plots helped investigate them. An in-depth description of the issues can be found here.
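For reference, here is a minimal sketch of how this comparison can be computed, assuming two standard transformers causal language models (e.g. GPT2LMHeadModel) rather than TRL's value-head wrapper.

```python
import torch
import torch.nn.functional as F

def logprob_gap(model, ref_model, input_ids):
    """Per-token log-probability difference between the trained model and the
    frozen reference model; its mean over generated tokens is a KL proxy."""
    with torch.no_grad():
        logits = model(input_ids).logits[:, :-1, :]      # predict token t+1 from prefix up to t
        ref_logits = ref_model(input_ids).logits[:, :-1, :]
    targets = input_ids[:, 1:].unsqueeze(-1)
    logp = F.log_softmax(logits, dim=-1).gather(-1, targets).squeeze(-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1).gather(-1, targets).squeeze(-1)
    return logp - ref_logp                               # positive where the policy drifts
```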
Timing
Looking at the times for each part of the training loop, we can see that most of the total time is spent on the optimisation step, followed by the text generation step that creates the response to the query.
This graph was useful during development to figure out which parts of the code needed optimising. At the moment the main bottleneck is the PPO optimisation, and more specifically the backward pass, which is to be expected.
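To get a similar breakdown for your own loop, wrapping each phase in a simple wall-clock timer is enough. The sketch below assumes wandb.init has been called and uses illustrative key names.

```python
import time
from contextlib import contextmanager
import wandb

timing = {}

@contextmanager
def timed(name):
    # Accumulate wall-clock time for one phase of the training loop.
    start = time.time()
    yield
    timing[name] = time.time() - start

# Inside the training step (phase bodies elided):
with timed("time/generation"):
    pass  # text generation for the batch of queries
with timed("time/sentiment"):
    pass  # reward computation with the sentiment classifier
with timed("time/optimization"):
    pass  # PPO forward and backward passes

wandb.log({**timing, "time/total": sum(timing.values())})
```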