In this tutorial, we’ll build a near state-of-the-art sentence classifier that leverages recent breakthroughs in Natural Language Processing. We’ll focus on an application of transfer learning to NLP, using it to create high-performance models with minimal effort on a range of NLP tasks. Google’s BERT allowed researchers to smash multiple benchmarks with minimal task-specific fine-tuning. As a result, NLP research reproduction and experimentation have become more accessible.

We'll use WandB's hyperparameter Sweeps later on; there's a simple notebook available to get you started with Sweeps.

Try BERT fine-tuning in a Colab →


Loading and Tokenizing the CoLA Dataset

We’ll use The Corpus of Linguistic Acceptability (CoLA) dataset for single sentence classification. It’s a set of sentences labeled as grammatically correct or incorrect. It was first published in May of 2018, and is one of the tests included in the “GLUE Benchmark” on which models like BERT are competing.

import wget
import pandas as pd

# Download the public CoLA release if it isn't present already.
url = '', './')
# Unzip before reading, e.g. `unzip`.

df = pd.read_csv("./cola_public/raw/in_domain_train.tsv", delimiter='\t', header=None,
                 names=['sentence_source', 'label', 'label_notes', 'sentence'])

print('Number of training sentences: {:,}\n'.format(df.shape[0]))


Before we process the entire dataset with the BERT tokenizer, there are a few conditions we need to satisfy to set up the training data for BERT: every sentence must be wrapped in the special [CLS] and [SEP] tokens, padded or truncated to a single fixed length, and paired with an attention mask that distinguishes real tokens from padding.
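To make these conditions concrete, here's a schematic sketch of the input format in plain Python. This is an illustration of the layout only, not the real tokenizer, and `to_bert_input` is a hypothetical helper:

```python
def to_bert_input(tokens, max_length):
    # 1. Wrap the sentence in the special [CLS] and [SEP] tokens.
    seq = ['[CLS]'] + tokens + ['[SEP]']
    # 2. Truncate, then pad with [PAD] to a single fixed length.
    seq = seq[:max_length] + ['[PAD]'] * (max_length - len(seq))
    # 3. Attention mask: 1 for real tokens, 0 for padding.
    mask = [1 if tok != '[PAD]' else 0 for tok in seq]
    return seq, mask

tokens, mask = to_bert_input(['the', 'cat', 'sat'], max_length=8)
print(tokens)  # ['[CLS]', 'the', 'cat', 'sat', '[SEP]', '[PAD]', '[PAD]', '[PAD]']
print(mask)    # [1, 1, 1, 1, 1, 0, 0, 0]
```

The real tokenizer below handles all three steps in one call to `encode_plus`.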


from transformers import BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

sentences = df.sentence.values
labels = df.label.values

# Find the length of the longest tokenized sentence.
max_len = 0
for sent in sentences:
    input_ids = tokenizer.encode(sent, add_special_tokens=True)
    max_len = max(max_len, len(input_ids))

# Tokenize every sentence, collecting input IDs and attention masks.
input_ids = []
attention_masks = []
for sent in sentences:
    encoded_dict = tokenizer.encode_plus(
        sent,
        add_special_tokens = True,     # add '[CLS]' and '[SEP]'
        max_length = 64,
        pad_to_max_length = True,      # pad/truncate to 64 tokens
        return_attention_mask = True,
        return_tensors = 'pt',
    )
    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

# Stack the per-sentence tensors into single tensors.
input_ids =, dim=0)
attention_masks =, dim=0)
labels = torch.tensor(labels)


Splitting and Loading

Splitting the Dataset

We’ll split the data into training and validation sets, using 90% of the samples for training and holding out 10% for validation.

from import TensorDataset, random_split

# Combine the training inputs into a TensorDataset.
dataset = TensorDataset(input_ids, attention_masks, labels)

# Create a 90-10 train-validation split.

# Calculate the number of samples to include in each set.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# Divide the dataset by randomly selecting samples.
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
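The two sizes are computed so that every sample lands in exactly one split, with `int()` flooring the training size and the validation size picking up the remainder. A quick sketch of that arithmetic (`split_sizes` is an illustrative helper, not part of PyTorch):

```python
def split_sizes(n, train_frac=0.9):
    # Floor the training size; the validation set gets whatever is left,
    # so the two sizes always sum to n with no sample lost or duplicated.
    train_size = int(train_frac * n)
    val_size = n - train_size
    return train_size, val_size

print(split_sizes(8551))  # (7695, 856)
```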

Loading The Dataset

The next steps require us to choose values for several hyper-parameters. We’ll automate that task by sweeping across combinations of all the parameter values. To do this, we’ll initialize a wandb run before starting the training loop; the hyper-parameter value chosen for the current run is then available as wandb.config.parameter_name.

from import DataLoader, RandomSampler, SequentialSampler
import wandb

def ret_dataloader():
    batch_size = wandb.config.batch_size
    print('batch_size = ', batch_size)
    train_dataloader = DataLoader(train_dataset,
                                  sampler = RandomSampler(train_dataset),
                                  batch_size = batch_size)
    validation_dataloader = DataLoader(val_dataset,
                                       sampler = SequentialSampler(val_dataset),
                                       batch_size = batch_size)
    return train_dataloader, validation_dataloader
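The only difference between the two loaders is the sampler: training batches are drawn in shuffled order, while validation batches keep the dataset's order. A plain-Python sketch of that behaviour (`batches` is a hypothetical helper, not the PyTorch API):

```python
import random

def batches(data, batch_size, shuffle):
    # RandomSampler shuffles the indices each epoch;
    # SequentialSampler visits them in their original order.
    idx = list(range(len(data)))
    if shuffle:
        random.shuffle(idx)
    for i in range(0, len(idx), batch_size):
        yield [data[j] for j in idx[i:i + batch_size]]

# Sequential (validation-style) batching preserves order:
print(list(batches(list(range(10)), 4, shuffle=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```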

Setting Up Model For Training

We’ll use the pre-trained BertForSequenceClassification model, which places a single dense (fully-connected) layer on top of BERT to perform binary classification. We’ll keep each part of the program in a separate function block.

from transformers import BertForSequenceClassification, AdamW

def ret_model():
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels = 2,
        output_attentions = False,
        output_hidden_states = False,
    )
    return model

def ret_optim(model):
    print('Learning_rate = ', wandb.config.learning_rate)
    optimizer = AdamW(model.parameters(),
                      lr = wandb.config.learning_rate,
                      eps = 1e-8)
    return optimizer

The num_labels parameter describes the number of final output neurons.

We’ll use an implementation of Adam optimizer with an inbuilt weight-decay mechanism from HuggingFace. We'll pass the learning rate from wandb.config.
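For intuition, here's a schematic single-parameter version of the decoupled weight-decay update that gives AdamW its name. This is a sketch, not HuggingFace's implementation, and the `wd` default is illustrative:

```python
def adamw_step(p, grad, m, v, t, lr=2e-5, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    # Standard Adam first/second moment estimates with bias correction.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    # Decoupled weight decay: shrinks the weight directly instead of being
    # folded into the gradient estimate -- that's the "W" in AdamW.
    p = p - lr * (m_hat / (v_hat ** 0.5 + eps) + wd * p)
    return p, m, v

# One update on a single weight: the weight moves against the gradient
# and is additionally pulled slightly toward zero by the decay term.
p, m, v = adamw_step(p=1.0, grad=1.0, m=0.0, v=0.0, t=1)
print(p)
```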

We’ll also initialize a learning rate scheduler to perform learning rate decay. The number of training epochs is also a hyper-parameter, so we’ll read it from wandb.config as well.

from transformers import get_linear_schedule_with_warmup

def ret_scheduler(train_dataloader, optimizer):
    epochs = wandb.config.epochs
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps = 0,
                                                num_training_steps = total_steps)
    return scheduler
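The resulting schedule rises linearly during warmup and then decays linearly to zero over the remaining steps. A small pure-Python sketch of the multiplier (`linear_lr` is an illustration, not the library's code):

```python
def linear_lr(step, total_steps, base_lr, num_warmup_steps=0):
    # Linear warmup from 0 up to base_lr, then linear decay back to 0.
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - num_warmup_steps))

base_lr = 2e-5
print(linear_lr(0, 100, base_lr))    # full LR immediately (num_warmup_steps=0)
print(linear_lr(50, 100, base_lr))   # half the LR, halfway through training
print(linear_lr(100, 100, base_lr))  # 0.0 at the final step
```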

Set Up A Hyperparameter Sweep

There’s only one step left before we train our model.

We’ll create a configuration file that’ll list all the values a hyper-parameter can take.

Then we’ll initialize our wandb sweep agent to log, compare and visualize the performance of each combination.

The metric we’re looking to maximize is the val_accuracy which we’ll log in the training loop.

In the BERT paper, the authors described the best set of hyper-parameters to use for transfer learning, and we’re using that same set of values for our hyper-parameters.

import wandb

sweep_config = {
    'method': 'random',  # 'grid' or 'random'
    'metric': {
        'name': 'val_accuracy',
        'goal': 'maximize'
    'parameters': {
        'learning_rate': {
            'values': [5e-5, 3e-5, 2e-5]
        'batch_size': {
            'values': [16, 32]
        'epochs': {
            'values': [2, 3, 4]

sweep_id = wandb.sweep(sweep_config)

The Training Function

import random
import time
import numpy as np
import torch

def train():
    # Each sweep run starts with wandb.init() so that wandb.config is populated.
    wandb.init()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ret_model()
    train_dataloader, validation_dataloader = ret_dataloader()
    optimizer = ret_optim(model)
    scheduler = ret_scheduler(train_dataloader, optimizer)
    epochs = wandb.config.epochs

    for epoch_i in range(0, epochs):
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
        t0 = time.time()
        total_train_loss = 0
        for step, batch in enumerate(train_dataloader):
            if step % 40 == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(
                    step, len(train_dataloader), elapsed))
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)
            loss, logits = model(b_input_ids,
            # Log the train loss
            wandb.log({'train_batch_loss': loss.item()})
            total_train_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        avg_train_loss = total_train_loss / len(train_dataloader)
        training_time = format_time(time.time() - t0)
        # Log the avg. train loss
        wandb.log({'avg_train_loss': avg_train_loss})
        print("  Average training loss: {0:.2f}".format(avg_train_loss))

        print("Running Validation...")
        t0 = time.time()
        total_eval_accuracy = 0
        total_eval_loss = 0
        # Evaluate data for one epoch
        for batch in validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)
            with torch.no_grad():
                (loss, logits) = model(b_input_ids,
            total_eval_loss += loss.item()
            logits = logits.detach().cpu().numpy()
            label_ids ='cpu').numpy()
            total_eval_accuracy += flat_accuracy(logits, label_ids)

        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
        avg_val_loss = total_eval_loss / len(validation_dataloader)
        validation_time = format_time(time.time() - t0)
        # Log the avg. validation accuracy and loss
        wandb.log({'val_accuracy': avg_val_accuracy, 'avg_val_loss': avg_val_loss})
        print("  Accuracy: {0:.2f}".format(avg_val_accuracy))
        print("  Validation Loss: {0:.2f}".format(avg_val_loss))
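The training loop relies on two small helpers, format_time and flat_accuracy, that aren't defined elsewhere in the post. Plausible definitions that match how they're called would be:

```python
import datetime
import numpy as np

def format_time(elapsed):
    # Round elapsed seconds to the nearest whole second and
    # render as hh:mm:ss.
    elapsed_rounded = int(round(elapsed))
    return str(datetime.timedelta(seconds=elapsed_rounded))

def flat_accuracy(logits, labels):
    # Fraction of rows where the argmax over the logits
    # matches the true label.
    preds = np.argmax(logits, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(preds == labels_flat) / len(labels_flat)
```

With everything defined, the sweep is launched with `wandb.agent(sweep_id, function=train)`, which repeatedly samples a hyperparameter combination and calls `train()` for each run.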



Some of the most essential information about model performance can be read directly from the parallel coordinates plot.

Here, we can see how all of the runs performed, given the task of maximizing the validation accuracy. We can see the best hyperparameter values from running the sweeps. The highest validation accuracy that was achieved in this batch of sweeps is around 84%.

If you want a simpler visualization that compares the time taken by each run as well as how well it optimized the desired metric, the dashboard offers another useful plot comparing the validation accuracy vs. the runtime of each run.


Now you have a near state-of-the-art BERT model for sentence classification, trained with the best-performing set of hyper-parameter values and accompanied by various statistical visualizations.

I encourage you to try running a sweep with more hyperparameter combinations to see if you can improve the performance of the model.

Try BERT fine-tuning in a Colab →