
How to Fine-Tune BERT for Text Classification

A code-first, reader-friendly kickstart to fine-tuning BERT for text classification with tf.data and TensorFlow Hub
Created on August 2 | Last edited on November 1


What is BERT?

Bidirectional Encoder Representations from Transformers, better known as BERT, is a model introduced by Google in a landmark paper. It improved the state of the art on a wide range of NLP tasks and became the stepping stone for many later architectures.
It's not an exaggeration to say that BERT set a new direction for the entire field. It demonstrated the clear benefits of pre-training a model on huge datasets and then transferring that knowledge to a variety of downstream tasks.
In this report, we're going to look at using BERT for text classification and provide a ton of code and examples to get you up and running. If you'd like to check out the primary source yourself, here's a link to the annotated paper. 


BERT Classification Model

Setting Up BERT for Text Classification

First, let's check which TensorFlow version is installed and clone the TensorFlow Model Garden:
import tensorflow as tf
print(tf.version.VERSION)
!git clone --depth 1 -b v2.4.0 https://github.com/tensorflow/models.git
This clones the GitHub repo for TensorFlow models (the Model Garden). A few things of note:
  • --depth 1: during cloning, Git fetches only the latest commit instead of the full history. It can save you a lot of space and time.
  • -b lets us clone a specific branch or tag only (v2.4.0 here).
Please match this tag with your TensorFlow 2.x version.
# install requirements to use tensorflow/models repository
!pip install -Uqr models/official/requirements.txt
# you may have to restart the runtime afterwards, also ignore any ERRORS popping up at this step
It's raining imports in here, friends.
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import sys
sys.path.append('models')
from official.nlp.data import classifier_data_lib
from official.nlp.bert import tokenization
from official.nlp import optimization
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set()
import wandb
from wandb.keras import WandbCallback
A quick sanity-check of different versions and dependencies installed:
print("TF Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print("GPU is", "available" if tf.config.experimental.list_physical_devices("GPU") else "NOT AVAILABLE")

Let's Get the Dataset

The dataset we'll use today is provided via the Quora Insincere Questions Classification competition on Kaggle.
Feel free to download train.csv from the competition page on Kaggle, or load it directly from the archive link used in the code below.

Decompress and Read the Data into a pandas DataFrame:

Next, run the following:
# TO LOAD DATA FROM ARCHIVE LINK
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.read_csv('https://archive.org/download/quora_dataset_train.csv/quora_dataset_train.csv.zip',
compression='zip',
low_memory=False)
print(df.shape)
df.head(10)
# label 0 == sincere (non-toxic) question
# label 1 == insincere (toxic) question
Alright. Now, let's quickly visualize that data in a W&B Table:



Let's Explore

The Label Distribution

It's a good idea to understand the data you're working with before you really dig into modeling. Here, we're going to walk through our label distribution and how long our data points are, make certain that our train and validation sets are well distributed, and handle a few other preliminary tasks. First, though, let's look at the label distribution by running:
print(df['target'].value_counts())
df['target'].value_counts().plot.bar()
plt.yscale('log');
plt.title('Distribution of Labels')

Label Distribution

Word Length and Character Length

Now, let's run a few lines of code to understand the text data we're working with here.
print('Average word length of questions in dataset is {0:.0f}.'.format(np.mean(df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in dataset is {0:.0f}.'.format(np.max(df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in dataset is {0:.0f}.'.format(np.mean(df['question_text'].apply(lambda x: len(x)))))



Preparing Training and Testing Data for Our BERT Text Classification Tasks

A few notes on our approach here:
  • We'll use only a small portion of the data, since training on the full dataset would take ages. You can of course include more data by changing train_size.
  • Since the dataset is very imbalanced, we'll keep the same label distribution in both the train and validation sets by stratifying the split on the labels. Later in this section, we'll analyze the data to make sure this worked.
train_df, remaining = train_test_split(df, random_state=42, train_size=0.1, stratify=df.target.values)
valid_df, _ = train_test_split(remaining, random_state=42, train_size=0.01, stratify=remaining.target.values)
print(train_df.shape)
print(valid_df.shape)
(130612, 3)
(11755, 3)
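The stratify argument is what keeps the class ratio intact across both splits. As a rough illustration of the idea (not scikit-learn's actual implementation), the plain-Python sketch below splits each class separately with the same fraction. The toy labels and the helper names stratified_split and share_of_positives are hypothetical.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, train_fraction, seed=42):
    """Return (train_idx, rest_idx) so each class keeps roughly the same share in both parts.

    Illustrative sketch of stratified sampling; not scikit-learn's algorithm.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, rest_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = round(len(indices) * train_fraction)  # per-class quota
        train_idx.extend(indices[:n_train])
        rest_idx.extend(indices[n_train:])
    return sorted(train_idx), sorted(rest_idx)


# Toy imbalanced labels (made-up proportions): 940 of class 0, 60 of class 1
labels = [0] * 940 + [1] * 60
train_idx, rest_idx = stratified_split(labels, train_fraction=0.1)


def share_of_positives(idx):
    counts = Counter(labels[i] for i in idx)
    return counts[1] / len(idx)


print(len(train_idx), len(rest_idx))                                 # 100 900
print(share_of_positives(train_idx), share_of_positives(rest_idx))   # 0.06 0.06
```

Both parts end up with the same 6% share of positives, which is exactly what the length and label checks below are meant to confirm on the real data.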

Getting the Word and Character Length for the Sampled Sets

print("FOR TRAIN SET\n")
print('Average word length of questions in train set is {0:.0f}.'.format(np.mean(train_df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in train set is {0:.0f}.'.format(np.max(train_df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in train set is {0:.0f}.'.format(np.mean(train_df['question_text'].apply(lambda x: len(x)))))
print('Label Distribution in train set is \n{}.'.format(train_df['target'].value_counts()))
print("\n\nFOR VALIDATION SET\n")
print('Average word length of questions in valid set is {0:.0f}.'.format(np.mean(valid_df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in valid set is {0:.0f}.'.format(np.max(valid_df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in valid set is {0:.0f}.'.format(np.mean(valid_df['question_text'].apply(lambda x: len(x)))))
print('Label Distribution in validation set is \n{}.'.format(valid_df['target'].value_counts()))
Judging by these statistics, the train and validation sets are similar in terms of class imbalance and question length.


Analyzing the Distribution of Question Text Length in Words

# TRAIN SET
train_df['question_text'].apply(lambda x: len(x.split())).plot(kind='hist');
plt.yscale('log');
plt.title('Distribution of question text length in words')


# VALIDATION SET
valid_df['question_text'].apply(lambda x: len(x.split())).plot(kind='hist');
plt.yscale('log');
plt.title('Distribution of question text length in words')




Analyzing the Distribution of Question Text Length in Characters

As we dig into our train and validation sets, one more thing to check is whether question text length is similar between the two. Roughly similar distributions mean the validation metrics reflect the same kind of data the model was trained on, so they give a fairer picture of how well it generalizes.
# TRAIN SET
train_df['question_text'].apply(lambda x: len(x)).plot(kind='hist');
plt.yscale('log');
plt.title('Distribution of question text length in characters')


# VALIDATION SET
valid_df['question_text'].apply(lambda x: len(x)).plot(kind='hist');
plt.yscale('log');
plt.title('Distribution of question text length in characters')

And it is. The distributions of question length in words and in characters are very similar across both sets, so the train/validation split looks good so far.


Taming the Data

Next, we want the dataset to be created and preprocessed on the CPU:
with tf.device('/cpu:0'):
    train_data = tf.data.Dataset.from_tensor_slices((train_df['question_text'].values, train_df['target'].values))
    valid_data = tf.data.Dataset.from_tensor_slices((valid_df['question_text'].values, valid_df['target'].values))

    # let's look at 3 samples from the train set
    for text, label in train_data.take(3):
        print(text)
        print(label)

print(len(train_data))
print(len(valid_data))
130612
11755
Okay. Let's BERT.

Let's BERT: Get the Pre-trained BERT Model from TensorFlow Hub

We'll be using the uncased BERT model from TensorFlow Hub.
To prepare the text for the BERT layer, we first need to tokenize it. The vocabulary file the tokenizer needs ships with the model as an asset, along with a flag telling us whether to lowercase the input, so the tokenizer handles uncasing for us.
We'll keep all parameters in a single dictionary, so any changes can be made in one place:
# Setting some parameters

config = {'label_list' : [0, 1], # Label categories
'max_seq_length' : 128, # maximum length of (token) input sequences
'train_batch_size' : 32,
'learning_rate': 2e-5,
'epochs':5,
'optimizer': 'adam',
'dropout': 0.5,
'train_samples': len(train_data),
'valid_samples': len(valid_data),
'train_split':0.1,
'valid_split': 0.01
}
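Since BERT inputs are capped at max_seq_length tokens, it can be worth estimating how many questions would be truncated. The sketch below uses whitespace word counts as a stand-in for WordPiece tokens and reserves two slots for [CLS] and [SEP]. The helper name and sample texts are hypothetical, and because WordPiece usually produces at least one token per word, this estimate may undercount truncation.

```python
def fraction_truncated(texts, max_seq_length=128, special_tokens=2):
    """Estimated share of texts that exceed the token budget.

    Assumption: whitespace word counts stand in for WordPiece tokens. WordPiece
    usually yields at least one token per word, so the real share can be higher.
    """
    budget = max_seq_length - special_tokens  # slots left after [CLS] and [SEP]
    too_long = sum(1 for text in texts if len(text.split()) > budget)
    return too_long / len(texts)


# Hypothetical questions, not taken from the Quora dataset
sample = ["How do I learn Python?", "word " * 130, "Why is the sky blue?"]
print(fraction_truncated(sample))  # one of the three questions exceeds 126 words
```

On the real data you could pass train_df['question_text'] instead of the toy list; given the average and maximum word lengths printed earlier, this gives a quick feel for whether 128 is a reasonable cap.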


Get the BERT layer and tokenizer:

# All details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2

bert_layer = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2',
trainable=True)
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy() # checks if the bert layer we are using is uncased or not
tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)


Checking out a sample string and its tokenized ids

input_string = "hello world, it is a wonderful day for learning"
print(tokenizer.wordpiece_tokenizer.tokenize(input_string))
print(tokenizer.convert_tokens_to_ids(tokenizer.wordpiece_tokenizer.tokenize(input_string)))
['hello', 'world', '##,', 'it', 'is', 'a', 'wonderful', 'day', 'for', 'learning']
[7592, 2088, 29623, 2009, 2003, 1037, 6919, 2154, 2005, 4083]
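The ##, in the output marks a continuation piece: WordPiece splits each word greedily, always taking the longest vocabulary entry that matches from the current position. The snippet above calls tokenizer.wordpiece_tokenizer directly; the full tokenizer.tokenize() first runs a basic tokenizer that splits off punctuation, so it would likely emit ',' as a separate token instead. Here is a minimal sketch of the greedy matching, using a tiny made-up vocabulary (not BERT's vocab.txt). The real implementation also handles details, such as a maximum word length, that this version leaves out.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word.

    Continuation pieces carry the "##" prefix, as in BERT's vocabulary.
    """
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it from the right.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece fits, so the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces


# Tiny hypothetical vocabulary for illustration only
vocab = {"world", "##,", "learn", "##ing", "un", "##believ", "##able"}
print(wordpiece_tokenize("world,", vocab))        # ['world', '##,']
print(wordpiece_tokenize("learning", vocab))      # ['learn', '##ing']
print(wordpiece_tokenize("unbelievable", vocab))  # ['un', '##believ', '##able']
print(wordpiece_tokenize("xyz", vocab))           # ['[UNK]']
```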


Let's Get That Data Ready: Tokenize and Preprocess Text for BERT

Each row of the dataset consists of the question text and its label. Data preprocessing transforms the text into BERT input features:
  • Input Word Ids: The output of our tokenizer, converting each sentence into a sequence of token ids.
  • Input Masks: Since we pad all sequences to 128 tokens (the max sequence length), we need a mask to make sure the padding does not interfere with the actual text tokens. The mask has 1 for real tokens and 0 for padding tokens, so only real tokens are attended to.
  • Segment Ids: For our text classification task there is only one sequence, so the segment_ids/input_type_ids are just a vector of 0s. Segment ids exist to tell the two sentences of a pair apart, which BERT's pre-training relies on.
BERT was pre-trained on two tasks:
  1. Masked language modeling: fill in randomly masked words in a sentence.
  2. Next sentence prediction: given two sentences, predict whether the second one actually follows the first.
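To make the three inputs concrete, here is a plain-Python sketch of the single-sentence layout: [CLS], the tokens, [SEP], then zero padding up to max_seq_length. It mirrors what convert_single_example is expected to produce, but it is an illustration, not the Model Garden code. The id mapping is an assumption: 101, 102 and 0 are taken as the [CLS], [SEP] and [PAD] ids of the uncased BERT vocabulary, and the ids for hello and world are copied from the tokenizer output above. For a sentence pair, the second sentence's positions would get segment id 1 instead of 0.

```python
def build_bert_inputs(tokens, token_to_id, max_seq_length=128):
    """Single-sentence BERT input layout: [CLS] tokens [SEP], then zero padding."""
    # Truncate first so [CLS] and [SEP] always fit within max_seq_length.
    tokens = ["[CLS]"] + tokens[: max_seq_length - 2] + ["[SEP]"]
    input_word_ids = [token_to_id[t] for t in tokens]
    input_mask = [1] * len(input_word_ids)      # 1 = real token, attended to
    input_type_ids = [0] * len(input_word_ids)  # one sentence -> all segment 0

    padding = max_seq_length - len(input_word_ids)
    input_word_ids += [0] * padding             # assumed [PAD] id 0
    input_mask += [0] * padding                 # 0 = padding, ignored by attention
    input_type_ids += [0] * padding
    return input_word_ids, input_mask, input_type_ids


# Assumed ids: [CLS]=101 and [SEP]=102 in the uncased BERT vocab;
# hello/world ids copied from the tokenizer output above.
token_to_id = {"[CLS]": 101, "[SEP]": 102, "hello": 7592, "world": 2088}
ids, mask, types = build_bert_inputs(["hello", "world"], token_to_id, max_seq_length=8)
print(ids)    # [101, 7592, 2088, 102, 0, 0, 0, 0]
print(mask)   # [1, 1, 1, 1, 0, 0, 0, 0]
print(types)  # [0, 0, 0, 0, 0, 0, 0, 0]
```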
# This function converts a row into input features and a label,
# using classifier_data_lib from the TensorFlow Model Garden we cloned earlier
def create_feature(text, label, label_list=config['label_list'], max_seq_length=config['max_seq_length'], tokenizer=tokenizer):
    """
    Converts a data point into features usable by BERT, via classifier_data_lib.

    Parameters:
      text: input text string
      label: label associated with the text
      label_list: (list) all possible labels
      max_seq_length: (int) maximum sequence length set for BERT
      tokenizer: the tokenizer object instantiated from the model assets

    Returns:
      feature.input_ids: the token ids for the input text string
      feature.input_mask: the padding mask
      feature.segment_ids: here just a vector of 0s, since there is only one sentence
      feature.label_id: the corresponding label id from label_list ([0, 1] here)
    """

    # since we only have 1 sentence for classification, text_b is None
    example = classifier_data_lib.InputExample(guid=None,
                                               text_a=text.numpy(),
                                               text_b=None,
                                               label=label.numpy())
    # since there is only 1 example, the index is 0
    feature = classifier_data_lib.convert_single_example(0, example, label_list,
                                                         max_seq_length, tokenizer)

    return (feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id)

  • You want to use Dataset.map to apply this function to each element of the dataset. Dataset.map runs in graph mode and Graph tensors do not have a value.
  • In graph mode, you can only use TensorFlow Ops and functions.
So you can't .map this function directly; you need to wrap it in a tf.py_function. The tf.py_function passes regular eager tensors (which have a value and a .numpy() method to access it) to the wrapped Python function.

Wrapping the Python Function into a TensorFlow op for Eager Execution

def create_feature_map(text, label):
    """
    A TensorFlow wrapper that applies create_feature to each dataset element.

    Parameters:
      text: the input text string
      label: the classification ground-truth label associated with the input string

    Returns:
      A tuple of a dictionary and the corresponding label_id. The dictionary
      contains input_word_ids, input_mask and input_type_ids.
    """

    input_ids, input_mask, segment_ids, label_id = tf.py_function(create_feature, inp=[text, label],
                                                                  Tout=[tf.int32, tf.int32, tf.int32, tf.int32])
    max_seq_length = config['max_seq_length']

    # py_function doesn't set the shape of the returned tensors.
    input_ids.set_shape([max_seq_length])
    input_mask.set_shape([max_seq_length])
    segment_ids.set_shape([max_seq_length])
    label_id.set_shape([])

    x = {
        'input_word_ids': input_ids,
        'input_mask': input_mask,
        'input_type_ids': segment_ids
    }
    return (x, label_id)
Each element passed to the model is a tuple of a feature dictionary x and a label. The dictionary keys must match the names of the model's Input layers.

Let the Data Flow: Creating the Final Input Pipeline Using tf.data



Apply the Transformation to our Train and Test Datasets

# Now we simply apply the transformation to our train and validation datasets
with tf.device('/cpu:0'):
    # train
    train_data = (train_data.map(create_feature_map,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
                  .shuffle(1000)
                  .batch(config['train_batch_size'], drop_remainder=True)
                  .prefetch(tf.data.experimental.AUTOTUNE))

    # valid
    valid_data = (valid_data.map(create_feature_map,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
                  .batch(config['train_batch_size'], drop_remainder=True)
                  .prefetch(tf.data.experimental.AUTOTUNE))
The resulting tf.data.Datasets return (features, labels) pairs, as expected by keras.Model.fit.
# train data spec: we can see that each input data point has now been converted
# into the BERT-specific input tensors
train_data.element_spec
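Because both pipelines batch with drop_remainder=True and a batch size of 32, any examples that don't fill a final batch are skipped. A quick check using the dataset sizes printed earlier (the helper name is hypothetical):

```python
def batches_and_dropped(num_examples, batch_size, drop_remainder=True):
    """Returns (number of batches, number of examples that never make it into a batch)."""
    full_batches, remainder = divmod(num_examples, batch_size)
    if drop_remainder:
        return full_batches, remainder  # the partial final batch is discarded
    return full_batches + (1 if remainder else 0), 0  # keep a smaller final batch


print(batches_and_dropped(130612, 32))  # train: (4081, 20)
print(batches_and_dropped(11755, 32))   # valid: (367, 11)
```

So each training epoch runs 4,081 steps and skips 20 examples (which 20 can vary between epochs, since the data is reshuffled), while evaluation on the unshuffled validation set always skips the same last 11 of 11,755 questions. Using drop_remainder=False for the validation set would keep them.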



Creating, Training & Tracking Our BERT Classification Model

Let's model our way to glory!!!

Create The Model

There are two outputs from the BERT Layer:
  • A pooled_output of shape [batch_size, 768] with representations for the entire input sequences.
  • A sequence_output of shape [batch_size, max_seq_length, 768] with representations for each input token (in context).
For the classification task, we are only concerned with the pooled_output:
# Building the model: input ---> BERT layer ---> classification head
def create_model():
    input_word_ids = tf.keras.layers.Input(shape=(config['max_seq_length'],),
                                           dtype=tf.int32,
                                           name="input_word_ids")

    input_mask = tf.keras.layers.Input(shape=(config['max_seq_length'],),
                                       dtype=tf.int32,
                                       name="input_mask")

    input_type_ids = tf.keras.layers.Input(shape=(config['max_seq_length'],),
                                           dtype=tf.int32,
                                           name="input_type_ids")

    pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, input_type_ids])
    # for classification we only care about the pooled output.
    # At this point we can play around with the classification head based on the
    # downstream task and its complexity

    drop = tf.keras.layers.Dropout(config['dropout'])(pooled_output)
    output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(drop)

    # inputs coming from the function
    model = tf.keras.Model(
        inputs={
            'input_word_ids': input_word_ids,
            'input_mask': input_mask,
            'input_type_ids': input_type_ids},
        outputs=output)

    return model


Training Your Model

# Calling the create model function to get the keras based functional model
model = create_model()

# using adam with a lr of 2*(10^-5), loss as binary cross entropy as only
# 2 classes and similarly binary accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=config['learning_rate']),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.BinaryAccuracy(),
tf.keras.metrics.PrecisionAtRecall(0.5),
tf.keras.metrics.Precision(),
tf.keras.metrics.Recall()])
#model.summary()
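The loss and metrics above operate on the sigmoid probability produced by the output layer. The pure-Python sketch below shows what they compute on a toy batch with a 0.5 decision threshold, and why accuracy alone is misleading on an imbalanced dataset like this one. It is a simplified illustration in the spirit of the Keras implementations, not their actual code; the helper names and numbers are made up.

```python
import math


def sigmoid(z):
    """Maps a logit to a probability, like the Dense(1, activation='sigmoid') head."""
    return 1.0 / (1.0 + math.exp(-z))


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)], with p clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


def to_predictions(y_prob, cutoff=0.5):
    """Turns probabilities into hard 0/1 predictions."""
    return [1 if p > cutoff else 0 for p in y_prob]


def accuracy(y_true, y_prob):
    preds = to_predictions(y_prob)
    return sum(1 for y, p in zip(y_true, preds) if y == p) / len(y_true)


def precision_recall(y_true, y_prob):
    preds = to_predictions(y_prob)
    tp = sum(1 for y, p in zip(y_true, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(y_true, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(y_true, preds) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0  # no positive predictions -> 0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Toy batch with 1 positive (insincere) question out of 10; made-up numbers
y_true = [0] * 9 + [1]
logits = [-2.0] * 10  # a lazy model that always leans towards class 0
probs = [sigmoid(z) for z in logits]

print(accuracy(y_true, probs))          # 0.9
print(precision_recall(y_true, probs))  # (0.0, 0.0)
print(binary_cross_entropy(y_true, probs))
```

A model that always predicts the majority class scores 90% accuracy on this toy batch yet catches none of the positives, which is why precision and recall are tracked alongside binary accuracy.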

Model Summary
Model Architecture Summary
One drawback of TF Hub is that the entire module is imported as a single Keras layer, so the model summary doesn't show the individual layers inside BERT or their parameters.
tf.keras.utils.plot_model(model=model, show_shapes=True, dpi=76, )

The official TF Hub page states that "All parameters in the module are trainable, and fine-tuning all parameters is the recommended practice." Therefore, we'll go ahead and train the entire model without freezing anything.

Experiment Tracking

Since you are here, I am sure you have a good idea about Weights and Biases but if not, then read along :)
To start experiment tracking, we'll create 'runs' on W&B:
wandb.init(): It initializes the run with basic project information parameters:
  • project: The project name, will create a new project tab where all the experiments for this project will be tracked
  • config: A dictionary of all parameters and hyper-parameters we wish to track
  • group: optional, but would help us to group by different parameters later on
  • job_type: describes the type of job, e.g. "train" or "evaluate"; it also helps when grouping different experiments later
# Update CONFIG dict with the name of the model.
config['model_name'] = 'BERT_EN_UNCASED'
print('Training configuration: ', config)

# Initialize W&B run
run = wandb.init(project='Finetune-BERT-Text-Classification',
config=config,
group='BERT_EN_UNCASED',
job_type='train')
Now, to log all the different metrics, we'll use a simple callback provided by W&B:
WandbCallback(): https://docs.wandb.ai/guides/integrations/keras
Yes, it is as simple as adding a callback :D
# Train model
# we keep the number of epochs low because the model starts to overfit on this limited data; feel free to change it
epochs = config['epochs']
history = model.fit(train_data,
validation_data=valid_data,
epochs=epochs,
verbose=1,
callbacks = [WandbCallback()])
run.finish()



Some Training Metrics and Graphs





Let's Evaluate

Let us do an evaluation on the validation set and log the scores using W&B.
wandb.log(): Log a dictionary of scalars (metrics like accuracy and loss) and any other type of wandb object. Here we will pass the evaluation dictionary as it is and log it.
# Initialize a new run for the evaluation-job
run = wandb.init(project='Finetune-BERT-Text-Classification',
config=config,
group='BERT_EN_UNCASED',
job_type='evaluate')



# Model Evaluation on validation set
evaluation_results = model.evaluate(valid_data,return_dict=True)

# Log scores using wandb.log()
wandb.log(evaluation_results)

# Finish the run
run.finish()


Saving the Models and Model Versioning

Finally, we're going to look at saving reproducible models with W&B. Namely, with Artifacts.

W&B Artifacts

To save the models and make it easier to track different experiments, we'll use W&B Artifacts, a way to save and version your datasets and models.
Within a run, there are three steps for creating and saving a model Artifact:
  • Create an empty Artifact with wandb.Artifact().
  • Add your model file to the Artifact with artifact.add_file().
  • Call run.log_artifact() to save the Artifact.
# Save model
model.save(f"{config['model_name']}.h5")

# Initialize a new W&B run for saving the model, changing the job_type
run = wandb.init(project='Finetune-BERT-Text-Classification',
config=config,
group='BERT_EN_UNCASED',
job_type='save')


# Save model as Model Artifact
artifact = wandb.Artifact(name=f"{config['model_name']}", type='model')
artifact.add_file(f"{config['model_name']}.h5")
run.log_artifact(artifact)

# Finish W&B run
run.finish()


Quick Sneak Peek into the W&B Dashboard

Things to note:
  • Grouping of experiments and runs.
  • Visualizations of all training logs and metrics.
  • Visualizations for system metrics could be useful when training on cloud instances or physical GPU machines
  • Hyperparameter tracking in the tabular form.
  • Artifacts: Model versioning and storage.



BERT Text Classification Summary & Code

I hope this hands-on tutorial was useful, and if you've read this far, I hope you're leaving with some good takeaways.
The full code for this post can be found here.