Text Summarization on HuggingFace
Summarization comes in two main flavors: abstractive and extractive. This report covers both and shows how to fine-tune a pretrained T5 transformer on a summarization dataset.
1. Introduction
Summarization is the task of producing a short summary of a long document such as a news story or research article. It broadly comes in two forms: extractive and abstractive summarization.
- Extractive Summarization - Extractive summarization shortens long documents (news articles, medical publications, research papers) by pulling the most important sentences out of the text verbatim, without modelling the overall context.
- Abstractive Summarization - Abstractive summarization is quite different. An extractive summary may or may not read as coherent text, because it is just a selection of important sentences from the document. An abstractive model instead tries to capture the meaning of the whole document and then generates a summary in its own words, which need not appear verbatim in the source. That is why abstractive summarization is commonly used for news; the Inshorts app, for example, uses an abstractive system to condense news stories into short summaries. A quick off-the-shelf example follows this list.
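To see abstractive summarization in action before any fine-tuning, you can run an off-the-shelf model through the transformers pipeline API. This is only an illustrative sketch and is not part of the fine-tuning workflow below; the input text is a placeholder.

from transformers import pipeline

# Off-the-shelf abstractive summarization with a pretrained T5 checkpoint
summarizer = pipeline('summarization', model='t5-base')

text = "..."  # paste any long news article here
print(summarizer(text, max_length=60, min_length=10)[0]['summary_text'])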
2. Dataset
We will use the CNN/DailyMail news summarization dataset here. T5 is a text-to-text transfer transformer that is trained on unlabelled and labelled data and can then be fine-tuned for individual text-to-text tasks. Here we fine-tune the t5-base pretrained checkpoint on CNN/DailyMail for the summarization task.
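For reference, the dataset can be pulled straight from the Hugging Face Hub with the datasets library. A minimal sketch of loading it and inspecting the two fields we use below, 'article' and 'highlights':

from datasets import load_dataset

# Load the CNN/DailyMail summarization dataset (version 3.0.0)
dataset = load_dataset('cnn_dailymail', '3.0.0')

print(dataset)                    # train / validation / test splits
sample = dataset['train'][0]
print(sample['article'][:300])    # the full news article (model input)
print(sample['highlights'])       # the human-written summary (target)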
Custom DataClass
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, dataset, tokenizer, source_len, summ_len):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.text_len = source_len
        self.summ_len = summ_len
        self.text = self.dataset['article']
        self.summary = self.dataset['highlights']

    def __len__(self):
        return len(self.text)

    def __getitem__(self, i):
        # Collapse whitespace in the article and its summary
        text = ' '.join(str(self.text[i]).split())
        summary = ' '.join(str(self.summary[i]).split())

        # Each source sequence is encoded, truncated and padded to the max source length
        source = self.tokenizer.batch_encode_plus([text],
                                                  max_length=self.text_len,
                                                  padding='max_length',
                                                  truncation=True,
                                                  return_tensors='pt')
        # Each target sequence is encoded, truncated and padded to the max summary length
        target = self.tokenizer.batch_encode_plus([summary],
                                                  max_length=self.summ_len,
                                                  padding='max_length',
                                                  truncation=True,
                                                  return_tensors='pt')

        source_ids = source['input_ids'].squeeze()
        source_masks = source['attention_mask'].squeeze()
        target_ids = target['input_ids'].squeeze()
        target_masks = target['attention_mask'].squeeze()

        return {'source_ids': source_ids.to(torch.long),
                'source_masks': source_masks.to(torch.long),
                'target_ids': target_ids.to(torch.long),
                'target_masks': target_masks.to(torch.long)}
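As a quick sanity check (a hypothetical snippet, not part of the original training code), you can encode a single example and confirm the tensor shapes match the chosen source and summary lengths:

# Hypothetical sanity check for CustomDataset (assumes the class above is defined)
from datasets import load_dataset
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained('t5-base')
sample = load_dataset('cnn_dailymail', '3.0.0', split='train[:10]')[:]  # small slice as a plain dict

ds = CustomDataset(sample, tokenizer, source_len=270, summ_len=160)
item = ds[0]
print(item['source_ids'].shape)   # torch.Size([270]) -- encoded article, padded/truncated
print(item['target_ids'].shape)   # torch.Size([160]) -- encoded highlights, padded/truncated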
3. Pretrained Model
Two architectures are commonly used for summarization: T5-based and BART-based. In our case we go with T5, the Text-to-Text Transfer Transformer, using T5ForConditionalGeneration as the pretrained model and T5TokenizerFast as the tokenizer.
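Both come from the transformers library; a minimal sketch of loading the t5-base checkpoint used throughout this report:

import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = T5TokenizerFast.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)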
4. Training
Now that the custom dataset class is in place to load and preprocess the summarization data, we can fine-tune the T5 model for summary generation.
import wandb

def train(epoch, model, tokenizer, loader, optimizer, device):
    model.train()
    for step, data in enumerate(loader, 0):
        y = data['target_ids'].to(device)
        y_ids = y[:, :-1].contiguous()          # decoder input ids: all tokens except the last
        lm_labels = y[:, 1:].clone().detach()   # labels: all tokens except the first
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100  # ignore padded positions in the loss

        source_ids = data['source_ids'].to(device)
        masks = data['source_masks'].to(device)

        outputs = model(input_ids=source_ids,
                        attention_mask=masks,
                        decoder_input_ids=y_ids,
                        labels=lm_labels)
        loss = outputs[0]

        if step % 10 == 0:
            print('Epoch: {} | Loss: {}'.format(epoch, loss.item()))
            wandb.log({'training_loss': loss.item()})  # log the loss every 10 steps (batches)

        optimizer.zero_grad()
        loss.backward()   # backpropagate the loss
        optimizer.step()  # update the weights
5. Evaluation
Finally, after training, it's time to evaluate the model on the validation data.
import torch

def validation(epoch, tokenizer, model, device, loader):
    model.eval()
    predictions = []
    actual = []
    with torch.no_grad():
        for step, data in enumerate(loader, 0):
            ids = data['source_ids'].to(device)
            mask = data['source_masks'].to(device)
            y_id = data['target_ids'].to(device)

            prediction = model.generate(input_ids=ids,
                                        attention_mask=mask,
                                        num_beams=2,
                                        max_length=170,
                                        repetition_penalty=2.5,
                                        length_penalty=1.0,
                                        early_stopping=True)

            # Decode the generated ids and the reference ids back into text
            preds = [tokenizer.decode(p, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                     for p in prediction]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                      for t in y_id]

            if step % 100 == 0:
                print('Completed step {}'.format(step))
                print('predictions', preds)
                print('actual', target)

            predictions.extend(preds)
            actual.extend(target)
    return predictions, actual
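The validation loop only returns decoded strings. If you also want a quantitative score, ROUGE is the standard summarization metric; below is a minimal sketch using the rouge_score package, which is not part of the original code and is shown only as one possible way to score the outputs.

# Hypothetical ROUGE scoring of the strings returned by validation()
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

def average_rouge_l(predictions, references):
    # Mean ROUGE-L F1 over all (prediction, reference) pairs
    scores = [scorer.score(ref, pred)['rougeL'].fmeasure
              for pred, ref in zip(predictions, references)]
    return sum(scores) / len(scores)

# Usage: average_rouge_l(predictions, actual)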
6. Integration
The main function below ties together everything defined above.
import re

import torch
import wandb
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration, T5TokenizerFast

def main():
    wandb.init(project='huggingface')
    epochs = 1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5TokenizerFast.from_pretrained('t5-base')

    ## Prepare dataset ##
    # We use the cnn_dailymail dataset for abstractive summarization
    dataset = load_dataset('cnn_dailymail', '3.0.0')

    # The full dataset is large, so we only take 8k rows for training and 4k for validation
    train_dataset = dataset['train'][:8000]
    val_dataset = dataset['validation'][:4000]

    def preprocess(dataset):
        # Strip parenthesised source tags and '--' separators from the articles
        dataset['article'] = [re.sub(r'\(.*?\)', '', t) for t in dataset['article']]
        dataset['article'] = [t.replace('--', '') for t in dataset['article']]
        return dataset

    train_dataset = preprocess(train_dataset)
    val_dataset = preprocess(val_dataset)

    train_dataset = CustomDataset(train_dataset, tokenizer, 270, 160)
    val_dataset = CustomDataset(val_dataset, tokenizer, 270, 160)

    train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=0)
    val_loader = DataLoader(dataset=val_dataset, batch_size=2, num_workers=0)

    # Define the model and optimizer
    model = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
    optimizer = Adam(model.parameters(), lr=3e-4, amsgrad=True)
    wandb.watch(model, log='all')

    # Call the train function
    for epoch in range(epochs):
        train(epoch, model, tokenizer, train_loader, optimizer, device)

    # Call the validation function
    for epoch in range(epochs):
        pred, target = validation(epoch, tokenizer, model, device, val_loader)
        print(pred, target)

main()
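After fine-tuning, the model can be used to summarize unseen articles. The sketch below is hypothetical: it assumes you save the fine-tuned weights at the end of main() (e.g. with model.save_pretrained) and reload them, and the article string is a placeholder.

import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = T5TokenizerFast.from_pretrained('t5-base')
# Swap 't5-base' for the directory of your saved fine-tuned checkpoint
model = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
model.eval()

article = "..."  # any news article text
inputs = tokenizer(article, max_length=270, truncation=True,
                   padding='max_length', return_tensors='pt').to(device)

summary_ids = model.generate(input_ids=inputs['input_ids'],
                             attention_mask=inputs['attention_mask'],
                             num_beams=2,
                             max_length=170,
                             repetition_penalty=2.5,
                             length_penalty=1.0,
                             early_stopping=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))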
7. Results
- Actual text: 'Spokesperson: Experts are investigating how the UK military health care worker got Ebola . It is being decided if the military worker infected in Sierra Leone will return to England . There have been some 24,000 reported cases and 10,000 deaths in the latest Ebola outbreak .', 'Iraqi forces make some progress as they seek to advance toward Tikrit . The city, best known to Westerners as Saddam Hussein's birthplace, was taken by ISIS in June .'
- Summary: 'Ebola outbreak has devastated parts of West Africa, with Sierra Leone, Guinea and Liberia hardest hit . Authorities are investigating how this person was exposed to the virus .', 'Iraqi forces appear to be making progress on the third day of a major offensive . The operation is part of a wide-scale offensive to retake Tikrit and Salahuddin province . Iran has provided advisers, weapons and ammunition to the Iraqi government .'