
Learning to Summarize using trlx

Implement text summarization from human feedback with CarperAI's trlx framework, following OpenAI's "Learning to summarize from human feedback" paper

Dataset

We use the TL;DR summarization dataset from the "Learning to summarize from human feedback" paper.
The dataset contains 129722 posts, with roughly 5% held out as a validation set.
  • SFT / Policy:
    • Train: 116722 samples
    • Dev: 6447 samples
    • Test: 6553 samples
Note: we use the dev split as the prompt data for PPO training and select the final model based on its mean reward on the dev split.
Example:
{'id': 't3_1hxu8s',
'subreddit': 'relationships',
'title': 'I (f/22) have to figure out if I want to still know these girls or not and would hate to sound insulting',
'post': "Not sure if this belongs here but it's worth a try. \n\nBackstory:\nWhen I (f/22) went through my first real breakup 2 years ago because he needed space after a year of dating roand it effected me more than I thought. It was a horrible time in my life due to living with my mother and finally having the chance to cut her out of my life. I can admit because of it was an emotional wreck and this guy was stable and didn't know how to deal with me. We ended by him avoiding for a month or so after going to a festival with my friends. When I think back I wish he just ended. So after he ended it added my depression I suffered but my friends helped me through it and I got rid of everything from him along with cutting contact. \n\nNow: Its been almost 3 years now and I've gotten better after counselling and mild anti depressants. My mother has been out of my life since then so there's been alot of progress. Being stronger after learning some lessons there been more insight about that time of my life but when I see him or a picture everything comes back. The emotions and memories bring me back down. \n\nHis friends (both girls) are on my facebook because we get along well which is hard to find and I know they'll always have his back. But seeing him in a picture or talking to him at a convention having a conversation is tough. Crying confront of my current boyfriend is something I want to avoid. \n\nSo I've been thinking that I have to cut contact with these girls because it's time to move on because it's healthier. It's best to avoid him as well. But will they be insulted? Will they accept it? Is there going to be awkwardness? I'm not sure if it's the right to do and could use some outside opinions.",
'summary': "I still have contact with an old ex's friends but can't stand to see or talk to him. His friends are really nice ,so how do I tell them I possibly want to unfriend them on Facebook because of him?"}
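For reference, a minimal sketch of loading these splits, assuming the post-processed TL;DR data is available on the Hugging Face Hub as "CarperAI/openai_summarize_tldr" with "prompt" / "label" columns (the dataset name and column names are assumptions, not stated in this report):

from datasets import load_dataset

tldr = load_dataset("CarperAI/openai_summarize_tldr")  # assumed Hub name
print({split: len(ds) for split, ds in tldr.items()})
# expected to roughly match the split sizes above: train 116722, valid 6447, test 6553

sample = tldr["train"][0]
print(sample["prompt"][:200])  # Reddit post formatted as a prompt
print(sample["label"])         # human-written reference summary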
  • Reward model comparison dataset
    • Train: 92858 samples
    • Dev: 83797 samples
Example:
{'info': {'id': 't3_26qoai',
'post': 'I [25M] have snooped in the past and copped up to it to my gf [25F] of 6 years. We talked it through. It had been a year or two since the last time. That\'s an issue I\'m working on.\n\nNow she has a new close male work friend. I won\'t go into details, but she hides things from me with him and does other things to make me a bit suspicious. So...I snooped again, and this time, all texts from her new friend have been deleted and I saw a google search for "how to get over a guy" near some searches of his name and views of his Facebook profile.\n\nI asked her about this guy, not mentioning the snooping, and she denied any feelings, we talked for a long time about our relationship and she insisted that she only loves me and I mean the world to her, and that she really wants to work towards getting this relationship back out of the rut we\'ve been in (we both work all the time and barely see each other).\n\nI think if I cop to the snooping, we might have a more honest conversation about what\'s actually going on (if something is) and why she\'s having these feelings so we can either work through it together (my preference) or move on. But obviously, it will open the pandora\'s box of the snooping.\n\nThink it\'s worth it to admit to the snooping to hopefully get to the bottom of this?',
'title': 'To admit or not to admit snooping...',
'subreddit': 'relationships'},
'split': 'train',
'summaries': [{'text': ' Snooped, found something, should I admit what I found so we can have a more honest conversation about it with less denial on her part?',
'policy': 'ref',
'note': ''},
{'text': " I snooped, we talked about it, she wants to work it out, I'm not sure. Is the snooping worth it?",
'policy': 'sup2',
'note': "Worth it in what sense? That it's finally out in the open?"}],
'choice': 0,
'worker': 'KZL1qeRzHNYSfDAuOctL1iyVV8WC5N',
'batch': 'batch4',
'extra': {}}
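Similarly, a sketch of loading the pairwise comparison data for reward modeling, assuming the Hub mirror "CarperAI/openai_summarize_comparisons" with "prompt" / "chosen" / "rejected" columns (again, names are assumptions):

from datasets import load_dataset

comparisons = load_dataset("CarperAI/openai_summarize_comparisons")  # assumed Hub name
pair = comparisons["train"][0]
# each record pairs one post with a human-preferred ("chosen") summary and a
# rejected one, derived from the worker choices shown above
print(pair["prompt"][:120])
print("chosen:  ", pair["chosen"])
print("rejected:", pair["rejected"])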

Training Process

  • Step 1: Use the train split of the TL;DR dataset to train a supervised fine-tuning (SFT) model based on GPT-J 6B (a minimal sketch follows this list).
  • Step 2: Train a reward model initialized from the SFT model.
  • Step 3: Train the policy model against the learned reward model with the PPO algorithm (trlx).
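A minimal sketch of step 1 (SFT) with the Hugging Face Trainer; the hyperparameters, sequence length, and the format_example helper are illustrative assumptions rather than the exact recipe behind the results reported here:

from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

def format_example(ex):
    # concatenate the post prompt and its reference summary into one training sequence
    text = ex["prompt"] + " " + ex["label"] + tokenizer.eos_token
    return tokenizer(text, truncation=True, max_length=550)

train_data = load_dataset("CarperAI/openai_summarize_tldr", split="train").map(
    format_example, remove_columns=["prompt", "label"])

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="gptj-sft-tldr",          # placeholder output path
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        num_train_epochs=1,
    ),
    train_dataset=train_data,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),  # causal LM labels
)
trainer.train()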


Reward Model

  • Reward model implementation (pairwise ranking loss over chosen vs. rejected summaries):
import torch
from torch import nn
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM


class GPTRewardModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(config)
        self.config = model.config
        # gpt-neo models have hidden_size instead of n_embd
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        self.transformer = model.transformer
        # scalar value head that maps each hidden state to a reward
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = transformer_outputs[0]
        rewards = self.v_head(hidden_states).squeeze(-1)
        reward_scores = []
        # the batch is laid out as [chosen_0..chosen_{bs-1}, rejected_0..rejected_{bs-1}]
        bs = input_ids.shape[0] // 2
        chosen = input_ids[:bs]
        rejected = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]
        # compute pairwise loss. Only backprop on last value before padding
        loss = 0
        for i in range(bs):
            # Retrieve first index where trajectories diverge
            divergence_ind = (chosen[i] != rejected[i]).nonzero()[0]
            assert divergence_ind > 0
            # Check if there is any padding otherwise take length of sequence
            c_inds = (chosen[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
            r_inds = (rejected[i] == self.PAD_ID).nonzero()
            r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1]
            end_ind = max(c_ind, r_ind)
            # Index into correct reward
            c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind]
            r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind]
            reward_scores.append(c_truncated_reward[-1])
            loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
        loss = loss / bs
        return {
            "loss": loss,
            "chosen_end_scores": torch.stack(reward_scores),
        }
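A quick usage sketch: batches are laid out as all chosen sequences followed by all rejected sequences, so a single pair looks like this (the base checkpoint here is a placeholder; in step 2 the weights come from the SFT model):

import torch

reward_model = GPTRewardModel("EleutherAI/gpt-j-6B")  # placeholder; use the SFT checkpoint in practice
tok = reward_model.tokenizer

post = "SUBREDDIT: r/relationships\nPOST: I snooped on my girlfriend's phone again.\nTL;DR:"
chosen = post + " I snooped, found deleted texts, and want to come clean so we can talk honestly."
rejected = post + " We both work a lot."

batch = tok([chosen, rejected], padding="max_length", max_length=550,
            truncation=True, return_tensors="pt")
with torch.no_grad():
    out = reward_model(input_ids=batch["input_ids"],
                       attention_mask=batch["attention_mask"])
print(out["loss"].item())        # pairwise ranking loss for this batch
print(out["chosen_end_scores"])  # scalar reward of the chosen summary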
  • Accuracy in reward modeling:

  • Normalize rewards: f(x) = rm(x) - rm(x_ref)
At the end of training, we normalize the reward model outputs such that the reference summaries from our dataset achieve a mean score of 0.
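A small sketch of this normalization; `score` is a hypothetical helper that runs a list of post+summary strings through the reward model above and returns a tensor of scalar scores:

def normalized_rewards(samples, reference_summaries, score):
    # f(x) = rm(x) - rm(x_ref): shift each sample's score by the score of the
    # human reference summary for the same post, so references average to 0
    return score(samples) - score(reference_summaries)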
Result:
Accuracy on the validation comparison dataset

RLHF with trlx

Hydra:
  • The teacher is the SFT model.
  • The reward model is built on the SFT model with its bottom layers frozen.
  • The policy model is initialized from the SFT model with 8 unfrozen layers (see the sketch below).
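A direct PyTorch illustration of the "8 unfrozen layers" setup (in trlx this is configured via num_layers_unfrozen; the checkpoint path is a placeholder for the SFT model):

from transformers import AutoModelForCausalLM

policy = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")  # placeholder for the SFT checkpoint

# freeze everything, then unfreeze only the top 8 transformer blocks
for p in policy.parameters():
    p.requires_grad = False
for block in policy.transformer.h[-8:]:
    for p in block.parameters():
        p.requires_grad = True
for p in policy.lm_head.parameters():  # keep the output head trainable as well (illustrative choice)
    p.requires_grad = True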


Hydra's architecture: memory is shared between the trained and frozen models. Grey layers are pre-trained GPT-J parameters for the teacher (SFT), blue layers are frozen parameters from the fine-tuned reward model, and pink layers are parameters updated during PPO training.
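Putting step 3 together, a sketch of the PPO training call in the shape of the trlx summarize example; the config filename and the `score` / `reference_for` helpers (wrapping the reward model and a post-to-reference lookup) are assumptions, not verbatim from this report:

import trlx
from trlx.data.configs import TRLConfig
from datasets import load_dataset

config = TRLConfig.load_yaml("configs/ppo_gptj_summarize.yml")  # hypothetical config path

dev = load_dataset("CarperAI/openai_summarize_tldr", split="valid")
prompts = [ex["prompt"] for ex in dev]            # dev-split posts as the new PPO prompts

def reward_fn(samples, **kwargs):
    refs = [reference_for(s) for s in samples]    # hypothetical lookup of the human reference
    return (score(samples) - score(refs)).tolist()  # f(x) = rm(x) - rm(x_ref)

trlx.train(
    reward_fn=reward_fn,
    prompts=prompts,
    eval_prompts=prompts[:1000],
    config=config,
)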





[W&B panels: run "Val-val trlx summarize", Run set 2]


Examples


[W&B examples table: Run set (1135)]