Sentence Classification With Huggingface BERT and W&B

Publish your model insights with interactive plots for performance metrics, predictions, and hyperparameters. Made by Ayush Chaurasia using W&B

Introduction

In this tutorial, we'll build a near state-of-the-art sentence classifier by leveraging recent breakthroughs in Natural Language Processing. We'll focus on applying transfer learning to NLP, which lets us create high-performance models with minimal effort across a range of NLP tasks. Google's BERT allowed researchers to smash multiple benchmarks with minimal fine-tuning for specific tasks. As a result, NLP research reproduction and experimentation has become more accessible.

We'll use W&B's hyperparameter Sweeps later on. [Here's a simple notebook to get you started with sweeps.](https://colab.research.google.com/drive/1SQ-FOgji8AiyrQ08sIVfDiA8OUw4bC12?usp=sharing)

Try BERT fine tuning in a colab →


Loading and Tokenizing The CoLA Dataset

Splitting and Loading

Splitting the Dataset

We'll split the data into train and test sets, dividing our training set so that 90% is used for training and 10% for validation.
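The split below assumes the CoLA sentences have already been tokenized into `input_ids`, `attention_masks`, and `labels` tensors. In the full notebook that's done with Hugging Face's `BertTokenizer`; here is a minimal pure-PyTorch sketch of just the padding and attention-mask logic those tensors encode (the toy token ids are illustrative):

```python
import torch

# Toy token-id sequences standing in for tokenizer output on two CoLA sentences.
# (The real notebook uses transformers' BertTokenizer, which returns these
# tensors directly; this sketch only shows the padding/masking step.)
sequences = [[101, 2023, 102], [101, 7592, 2088, 999, 102]]
max_len = max(len(s) for s in sequences)

input_ids = torch.zeros(len(sequences), max_len, dtype=torch.long)
attention_masks = torch.zeros(len(sequences), max_len, dtype=torch.long)
for i, seq in enumerate(sequences):
    input_ids[i, :len(seq)] = torch.tensor(seq)
    attention_masks[i, :len(seq)] = 1  # 1 = real token, 0 = padding

labels = torch.tensor([1, 0])  # binary acceptability labels
```

The attention mask is what tells BERT to ignore the zero-padded positions.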

from torch.utils.data import TensorDataset, random_split

# Combine the training inputs into a TensorDataset.
dataset = TensorDataset(input_ids, attention_masks, labels)

# Create a 90-10 train-validation split.

# Calculate the number of samples to include in each set.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# Divide the dataset by randomly selecting samples.
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

Loading The Dataset

The next steps require us to guess various hyper-parameter values. We'll automate that task by sweeping across all combinations of the parameter values. To do this, we'll initialize a wandb object before starting the training loop. The hyper-parameter value for the current run is available as wandb.config.parameter_name.


from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import wandb
# WANDB PARAMETER
def ret_dataloader():
    batch_size = wandb.config.batch_size
    print('batch_size = ', batch_size)
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)

    validation_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)
    return train_dataloader,validation_dataloader

Setting Up Model For Training

Set Up A Hyperparameter Sweep

There’s only one step left before we train our model.

We’ll create a configuration file that’ll list all the values a hyper-parameter can take.

Then we’ll initialize our wandb sweep agent to log, compare and visualize the performance of each combination.

The metric we’re looking to maximize is the val_accuracy which we’ll log in the training loop.

In the BERT paper, the authors describe the best sets of hyper-parameter values for transfer learning, and we're using those same values here.


import wandb
sweep_config = {
    'method': 'random', #grid, random
    'metric': {
      'name': 'val_accuracy',
      'goal': 'maximize'   
    },
    'parameters': {
        'learning_rate': {
            'values': [ 5e-5, 3e-5, 2e-5]
        },
        'batch_size': {
            'values': [16, 32]
        },
        'epochs':{
            'values':[2, 3, 4]
        }
    }
}
sweep_id = wandb.sweep(sweep_config)
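With 'method': 'random', the agent samples configurations rather than enumerating them. For reference, the grid above contains 3 × 2 × 3 = 18 combinations, which a quick check confirms:

```python
from itertools import product

# Mirror the values from sweep_config above.
learning_rates = [5e-5, 3e-5, 2e-5]
batch_sizes = [16, 32]
epochs = [2, 3, 4]

# Every combination a 'grid' sweep would enumerate.
combinations = list(product(learning_rates, batch_sizes, epochs))
print(len(combinations))  # 18
```

Random search is handy when the grid is larger than your compute budget; here the space is small enough that 'grid' would also be reasonable.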

The Training Function
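The loop below calls several helpers that aren't shown in this report: ret_model, ret_optim, ret_scheduler, flat_accuracy, and format_time. A plausible sketch of them, following the hyper-parameter choices from the BERT paper, is below (the transformers/wandb imports are deferred into the functions so the pure-Python metric helpers work standalone; treat the exact model and optimizer settings as assumptions):

```python
import datetime
import numpy as np

def ret_model():
    # Pretrained BERT with a 2-way classification head (CoLA is binary).
    from transformers import BertForSequenceClassification
    return BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2,
        output_attentions=False,
        output_hidden_states=False,
    )

def ret_optim(model):
    # Learning rate is drawn from the current sweep run's config.
    import wandb
    from transformers import AdamW
    return AdamW(model.parameters(), lr=wandb.config.learning_rate, eps=1e-8)

def ret_scheduler(train_dataloader, optimizer):
    # Linear decay over all training steps, no warmup.
    import wandb
    from transformers import get_linear_schedule_with_warmup
    total_steps = len(train_dataloader) * wandb.config.epochs
    return get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=total_steps)

def flat_accuracy(preds, labels):
    # Fraction of rows where the argmax logit matches the label.
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

def format_time(elapsed):
    # Seconds -> hh:mm:ss string for readable progress logging.
    return str(datetime.timedelta(seconds=int(round(elapsed))))
```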


import random
import time
import numpy as np
import torch

def train():
    wandb.init()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    model = ret_model()
    model.to(device)
    train_dataloader,validation_dataloader = ret_dataloader()
    optimizer = ret_optim(model)
    scheduler = ret_scheduler(train_dataloader,optimizer)
    training_stats = []
    total_t0 = time.time()
    epochs = wandb.config.epochs
    for epoch_i in range(0, epochs):
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
        t0 = time.time()
        total_train_loss = 0
        model.train()
        for step, batch in enumerate(train_dataloader):
            if step % 40 == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)

                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)
            model.zero_grad()        
            loss, logits = model(b_input_ids, 
                                token_type_ids=None, 
                                attention_mask=b_input_mask, 
                                labels=b_labels)
            # Log the train loss
            wandb.log({'train_batch_loss':loss.item()})
            total_train_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
        avg_train_loss = total_train_loss / len(train_dataloader)            
        training_time = format_time(time.time() - t0)
        #Log the Avg. train loss
        wandb.log({'avg_train_loss':avg_train_loss})
        print("")
        print("  Average training loss: {0:.2f}".format(avg_train_loss))
        print("Running Validation...")
        t0 = time.time()
        model.eval()
        total_eval_accuracy = 0
        total_eval_loss = 0
        nb_eval_steps = 0
        # Evaluate data for one epoch
        for batch in validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)
            with torch.no_grad():        
                (loss, logits) = model(b_input_ids, 
                                      token_type_ids=None, 
                                      attention_mask=b_input_mask,
                                      labels=b_labels)
                
        
            total_eval_loss += loss.item()
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()
            total_eval_accuracy += flat_accuracy(logits, label_ids)
            
        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
        print("  Accuracy: {0:.2f}".format(avg_val_accuracy))
        avg_val_loss = total_eval_loss / len(validation_dataloader)
        
        validation_time = format_time(time.time() - t0)
        #Log the Avg. validation accuracy
        wandb.log({'val_accuracy':avg_val_accuracy,'avg_val_loss':avg_val_loss})
        print("  Validation Loss: {0:.2f}".format(avg_val_loss))
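Finally, hand the sweep id and the training function to a wandb agent, which repeatedly pulls a hyper-parameter combination from the sweep server and calls train() once per run. This is a launch fragment, not a testable snippet; the optional count caps how many runs the agent executes:

```python
import wandb

# Launch the sweep: one call to train() per sampled configuration.
wandb.agent(sweep_id, function=train, count=10)
```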

Visualizations

Conclusion

Now you have a state-of-the-art BERT model, trained with the best set of hyper-parameter values for sentence classification, along with various statistical visualizations. The sweep results show which hyperparameter values performed best; the highest validation accuracy achieved in this batch of sweeps is around 84%.

I encourage you to try running a sweep with more hyperparameter combinations to see if you can improve the performance of the model.

Try BERT fine tuning in a colab →