How to Fine-Tune BERT for Text Classification

A code-first, reader-friendly kickstart to fine-tuning BERT for text classification with tf.data and tf.Hub. Made by Akshay Uppal using Weights & Biases
Akshay Uppal


  1. What is BERT?
  2. Setting up BERT for Text Classification
  3. The Quora Dataset We'll Use for Our Task
  4. Exploring the Dataset
  5. Getting BERT
  6. Using tf.data
  7. Creating, Training, and Tracking Our BERT Model
  8. Saving & Versioning the Model
  9. BERT Text Classification & Code

What is BERT?

Bidirectional Encoder Representations from Transformers, better known as BERT, is a model introduced in a revolutionary paper by Google that pushed the state-of-the-art performance on a wide range of NLP tasks and became the stepping stone for many later architectures.
It's not an exaggeration to say that BERT set a new direction for the entire field. It showed the clear benefits of pre-training a model on huge unlabeled text datasets and then transferring it, through fine-tuning, to whatever downstream task you care about.
In this report, we're going to look at using BERT for text classification and provide a ton of code and examples to get you up and running. If you'd like to check out the primary source yourself, here's a link to the annotated paper.
BERT Classification Model

Setting Up BERT for Text Classification

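First, let's check our TensorFlow version and then install the TensorFlow Model Garden:
import tensorflow as tf
print(tf.version.VERSION)

We'll also clone the GitHub repo for TensorFlow models. A few things of note:
  1. --depth 1 makes a shallow clone: Git only fetches the latest copy of the relevant files, which saves a lot of space and time.
  2. -b clones a specific branch or tag only (v2.4.0 here). Please match it with your TensorFlow 2.x version.

# shallow-clone the tensorflow/models repository at the tag matching your TF version
!git clone --depth 1 -b v2.4.0

# install requirements to use tensorflow/models repository
!pip install -Uqr models/official/requirements.txt
# you may have to restart the runtime afterwards; also ignore any ERRORS popping up at this step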
It's raining imports in here, friends.
import sys

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

# make the cloned Model Garden importable
sys.path.append('models')
from official.nlp.data import classifier_data_lib
from official.nlp.bert import tokenization
from official.nlp import optimization

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set()

import wandb
from wandb.keras import WandbCallback
A quick sanity-check of different versions and dependencies installed:
print("TF Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print("GPU is", "available" if tf.config.experimental.list_physical_devices("GPU") else "NOT AVAILABLE")

Let's Get the Dataset

The dataset we'll use today is provided via the Quora Insincere Questions Classification competition on Kaggle.
Please feel free to download the training set from Kaggle or use the link below to download the train.csv from that competition:

Decompress and Read the Data into a pandas DataFrame:

Next, run the following:
# TO LOAD DATA FROM ARCHIVE LINK
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# pass the path or URL of the zipped train.csv here
df = pd.read_csv('', compression='zip', low_memory=False)
print(df.shape)
df.head(10)
# label 0 == non toxic
# label 1 == toxic
Alright. Now, let's quickly visualize that data in a W&B Table:

Let's Explore

The Label Distribution

It's a good idea to understand the data you're working with before you really dig into modeling. Here, we're going to walk through our label distribution, how long our data points are, make sure our train and validation sets are similarly distributed, and a few other preliminary tasks. First though, let's look at the label distribution by running:
print(df['target'].value_counts())
df['target'].value_counts().plot(kind='bar')
plt.yscale('log')
plt.title('Distribution of Labels')
Label Distribution
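Because the labels are so imbalanced, plain accuracy can look deceptively good. The minimal sketch below uses made-up toy labels (not the real dataset) and hypothetical helper names. It computes the class proportions and the accuracy of a trivial model that always predicts the majority class.

```python
from collections import Counter


def label_proportions(labels):
    """Return {label: fraction of examples} for a list of labels."""
    counts = Counter(labels)
    total = len(labels)
    return {label: count / total for label, count in sorted(counts.items())}


def majority_baseline_accuracy(labels):
    """Accuracy of a 'model' that always predicts the most frequent label."""
    counts = Counter(labels)
    return max(counts.values()) / len(labels)


# Toy labels: 94 sincere (0) and 6 insincere (1) questions
toy_labels = [0] * 94 + [1] * 6
print(label_proportions(toy_labels))           # {0: 0.94, 1: 0.06}
print(majority_baseline_accuracy(toy_labels))  # 0.94
```

With a split like this, a model that never flags an insincere question already scores 94% accuracy. That is why precision and recall are worth tracking alongside binary accuracy.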

Word Length and Character Length

Now, let's run a few lines of code to understand the text data we're working with here.
print('Average word length of questions in dataset is {0:.0f}.'.format(np.mean(df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in dataset is {0:.0f}.'.format(np.max(df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in dataset is {0:.0f}.'.format(np.mean(df['question_text'].apply(lambda x: len(x)))))

Preparing Training and Testing Data for Our BERT Text Classification Tasks

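A few notes on our approach here: the full dataset has over 1.3 million questions, so to keep training fast we sample just 10% of it for training and 1% of the remaining rows for validation. Both splits are stratified on target, so they keep the same class imbalance as the full dataset.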
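# sample 10% of the data for training and 1% of the remainder for validation,
# stratifying on the label so both subsets keep the original class ratio
train_df, remaining = train_test_split(df, random_state=42, train_size=0.1,
                                       stratify=df.target.values)
valid_df, _ = train_test_split(remaining, random_state=42, train_size=0.01,
                               stratify=remaining.target.values)
print(train_df.shape, valid_df.shape)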
(130612, 3) (11755, 3)
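Stratifying means sampling each class separately, so both subsets keep the original label ratio. Here is a minimal pure-Python sketch of that idea. It uses a hypothetical stratified_sample helper. scikit-learn's train_test_split, which is what we actually use above, handles per-class rounding more carefully.

```python
import random
from collections import defaultdict


def stratified_sample(labels, fraction, seed=42):
    """Return sorted indices of a sample that keeps each label's share (approximately)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    sample = []
    for indices in by_label.values():
        k = round(len(indices) * fraction)  # per-class sample size
        sample.extend(rng.sample(indices, k))
    return sorted(sample)


labels = [0] * 900 + [1] * 100  # 10% positives
idx = stratified_sample(labels, fraction=0.1)
positives = sum(labels[i] for i in idx)
print(len(idx), positives)  # 100 10
```

The sample keeps exactly the 10% positive rate of the toy data. A purely random 10% sample would only match it on average.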

Getting the Word and Character Length for the Sampled Sets

print("FOR TRAIN SET\n")
print('Average word length of questions in train set is {0:.0f}.'.format(np.mean(train_df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in train set is {0:.0f}.'.format(np.max(train_df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in train set is {0:.0f}.'.format(np.mean(train_df['question_text'].apply(lambda x: len(x)))))
print('Label Distribution in train set is \n{}.'.format(train_df['target'].value_counts()))

print("\n\nFOR VALIDATION SET\n")
print('Average word length of questions in valid set is {0:.0f}.'.format(np.mean(valid_df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in valid set is {0:.0f}.'.format(np.max(valid_df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in valid set is {0:.0f}.'.format(np.mean(valid_df['question_text'].apply(lambda x: len(x)))))
print('Label Distribution in validation set is \n{}.'.format(valid_df['target'].value_counts()))
The train and validation sets look similar in terms of class imbalance and the various lengths of the question texts.

Analyzing the Distribution of Question Text Length in Words

# TRAIN SET
train_df['question_text'].apply(lambda x: len(x.split())).plot(kind='hist')
plt.yscale('log')
plt.title('Distribution of question text length in words')

# VALIDATION SET
valid_df['question_text'].apply(lambda x: len(x.split())).plot(kind='hist')
plt.yscale('log')
plt.title('Distribution of question text length in words')

Analyzing the Distribution of Question Text Length in Characters

As we dig into our train and validation sets, one other thing we want to check is whether the question text length is mostly similar between the two. Roughly similar distributions mean the validation scores reflect the same kind of text the model was trained on, rather than a biased slice of it.
# TRAIN SET
train_df['question_text'].apply(lambda x: len(x)).plot(kind='hist')
plt.yscale('log')
plt.title('Distribution of question text length in characters')

# VALIDATION SET
valid_df['question_text'].apply(lambda x: len(x)).plot(kind='hist')
plt.yscale('log')
plt.title('Distribution of question text length in characters')
And it is. The distributions of question length in both words and characters are very similar. It looks like a good train/validation split so far.

Taming the Data

Next, we want the dataset to be created and preprocessed on the CPU:
with tf.device('/cpu:0'):
    train_data = tf.data.Dataset.from_tensor_slices((train_df['question_text'].values,
                                                     train_df['target'].values))
    valid_data = tf.data.Dataset.from_tensor_slices((valid_df['question_text'].values,
                                                     valid_df['target'].values))

    # let's look at 3 samples from the train set
    for text, label in train_data.take(3):
        print(text)
        print(label)
130612 11755
Okay. Let's BERT.

Let's BERT: Get the Pre-trained BERT Model from TensorFlow Hub

We'll be using the uncased English BERT model from TensorFlow Hub.
To prepare the text for the BERT layer, we first need to tokenize it. The tokenizer's vocabulary ships with the model as an asset, and the tokenizer will lowercase the text for us as well.
We'll keep all parameters in a single dictionary, so any changes can be made in one place:
# Setting some parameters
config = {
    'label_list': [0, 1],       # label categories
    'max_seq_length': 128,      # maximum length of (token) input sequences
    'train_batch_size': 32,
    'learning_rate': 2e-5,
    'epochs': 5,
    'optimizer': 'adam',
    'dropout': 0.5,
    'train_samples': len(train_data),
    'valid_samples': len(valid_data),
    'train_split': 0.1,
    'valid_split': 0.01
}

Get the BERT layer and tokenizer:

# All details are on the model's TF Hub page
# pass the TF Hub handle of the uncased BERT model here
bert_layer = hub.KerasLayer('', trainable=True)

vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
# checks if the BERT layer we are using is uncased or not
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)

Checking out a sample string and its tokenized ids

input_string = "hello world, it is a wonderful day for learning"
print(tokenizer.wordpiece_tokenizer.tokenize(input_string))
print(tokenizer.convert_tokens_to_ids(tokenizer.wordpiece_tokenizer.tokenize(input_string)))
['hello', 'world', '##,', 'it', 'is', 'a', 'wonderful', 'day', 'for', 'learning'] [7592, 2088, 29623, 2009, 2003, 1037, 6919, 2154, 2005, 4083]
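The '##' prefix marks a piece that continues the previous one. WordPiece splits each whitespace-separated word greedily, always taking the longest vocabulary entry that matches from the left. Here is a toy sketch of that greedy longest-match-first loop. It uses a tiny made-up vocabulary, and "learning" is deliberately left out so that the word gets split. The real tokenizer uses BERT's full vocab file. When you call tokenizer.tokenize instead of wordpiece_tokenizer.tokenize, it also splits off punctuation first, so "world," would not produce '##,'.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into vocabulary pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, shrinking it from the right
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # continuation pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # nothing matched: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"hello", "world", "##,", "learn", "##ing"}
print([p for w in "hello world, learning".split() for p in wordpiece(w, toy_vocab)])
# ['hello', 'world', '##,', 'learn', '##ing']
```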

Let's Get That Data Ready: Tokenize and Preprocess Text for BERT

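Each line of the dataset is composed of the question text and its label. Data preprocessing consists of transforming the text into the three input features BERT expects:
  1. input_word_ids: the token ids of the text, wrapped in [CLS] and [SEP] and padded to max_seq_length.
  2. input_mask: 1 for real tokens and 0 for padding, so the model ignores the padding.
  3. input_type_ids (segment ids): which sentence each token belongs to. All of them are 0 here, since we classify a single sentence.
Segment ids exist because BERT was pre-trained on two tasks:
  1. Fill in randomly masked words in a sentence (masked language modeling).
  2. Given two sentences, predict whether the second actually follows the first in the original text (next sentence prediction).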
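To make the layout of these input features concrete, here is a pure-Python sketch of roughly what happens to a single sentence. The tokens are truncated to leave room for the special tokens, wrapped in [CLS] ... [SEP], and then padded to max_seq_length. build_features is a hypothetical helper, and the word-piece ids are taken from the tokenizer output above. We assume the uncased vocab ids 101 for [CLS], 102 for [SEP] and 0 for [PAD]. The Model Garden's convert_single_example is what the tutorial actually uses, and it also maps labels and handles sentence pairs.

```python
def build_features(token_ids, max_seq_length, cls_id=101, sep_id=102, pad_id=0):
    """Lay out one sentence the way BERT expects: [CLS] tokens [SEP] + padding."""
    token_ids = token_ids[: max_seq_length - 2]  # leave room for [CLS] and [SEP]
    input_ids = [cls_id] + token_ids + [sep_id]
    input_mask = [1] * len(input_ids)    # 1 = real token
    segment_ids = [0] * len(input_ids)   # single sentence: everything is segment 0
    padding = max_seq_length - len(input_ids)
    input_ids += [pad_id] * padding
    input_mask += [0] * padding          # 0 = padding, ignored by attention
    segment_ids += [0] * padding
    return input_ids, input_mask, segment_ids


ids, mask, segs = build_features([7592, 2088], max_seq_length=6)
print(ids)   # [101, 7592, 2088, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
print(segs)  # [0, 0, 0, 0, 0, 0]
```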
# This provides a function to convert a row to input features and a label.
# It uses classifier_data_lib, a module in the TensorFlow Model Garden we installed earlier
def create_feature(text, label, label_list=config['label_list'],
                   max_seq_length=config['max_seq_length'], tokenizer=tokenizer):
    """
    Converts the datapoint into usable features for BERT using the classifier_data_lib

    Parameters:
      text: Input text string
      label: label associated with the text
      label_list: (list) all possible labels
      max_seq_length: (int) maximum sequence length set for BERT
      tokenizer: the tokenizer object instantiated by the files in model assets

    Returns:
      feature.input_ids: The token ids for the input text string
      feature.input_mask: The padding mask generated
      feature.segment_ids: essentially here a vector of 0s since classification
      feature.label_id: the corresponding label id from label_list [0, 1] here
    """
    # since we only have 1 sentence for classification purposes, text_b is None
    example = classifier_data_lib.InputExample(guid=None,
                                               text_a=text.numpy(),
                                               text_b=None,
                                               label=label.numpy())
    # since only 1 example, the index=0
    feature = classifier_data_lib.convert_single_example(0, example, label_list,
                                                         max_seq_length, tokenizer)

    return (feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id)
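You can't .map this function directly, because it calls .numpy() on its inputs. That only works in eager mode, while Dataset.map traces the function into a graph. Instead, you need to wrap it in a tf.py_function, which passes regular eager tensors (with a value and a .numpy() method to access it) to the wrapped Python function.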

Wrapping the Python Function into a TensorFlow op for Eager Execution

def create_feature_map(text, label):
    """
    A TensorFlow function wrapper to apply the transformation on the dataset.

    Parameters:
      text: the input text string.
      label: the classification ground truth label associated with the input string

    Returns:
      A tuple of a dictionary and a corresponding label_id with it. The dictionary
      contains the input_word_ids, input_mask, input_type_ids
    """
    input_ids, input_mask, segment_ids, label_id = tf.py_function(
        create_feature, inp=[text, label],
        Tout=[tf.int32, tf.int32, tf.int32, tf.int32])

    max_seq_length = config['max_seq_length']

    # py_func doesn't set the shape of the returned tensors.
    input_ids.set_shape([max_seq_length])
    input_mask.set_shape([max_seq_length])
    segment_ids.set_shape([max_seq_length])
    label_id.set_shape([])

    x = {
        'input_word_ids': input_ids,
        'input_mask': input_mask,
        'input_type_ids': segment_ids
    }
    return (x, label_id)
Each element passed to the model is a tuple of a feature dictionary x and a label. The dictionary keys must match the names of the model's Input layers.

Let the Data Flow: Creating the Final Input Pipeline Using tf.data

Apply the Transformation to our Train and Test Datasets

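# Now we will simply apply the transformation to our train and validation datasets
with tf.device('/cpu:0'):
    # train
    train_data = (train_data.map(create_feature_map,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
                  .shuffle(1000)
                  .batch(32, drop_remainder=True)
                  .prefetch(tf.data.experimental.AUTOTUNE))

    # valid
    valid_data = (valid_data.map(create_feature_map,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
                  .batch(32, drop_remainder=True)
                  .prefetch(tf.data.experimental.AUTOTUNE))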
The resulting tf.data.Datasets return (features, labels) pairs, as expected by keras.Model.fit:
# train data spec: we can finally see the input datapoint is now converted
# to the BERT-specific input tensors
train_data.element_spec
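Because the pipeline above uses drop_remainder=True, the last partial batch is discarded, so one epoch yields floor(N / 32) batches. Here is a quick check of what that means for the two splits, using the sizes printed earlier. batches_per_epoch is a hypothetical helper written only for this arithmetic.

```python
def batches_per_epoch(num_examples, batch_size, drop_remainder=True):
    """Number of batches one pass over the data yields."""
    full, rest = divmod(num_examples, batch_size)
    if drop_remainder or rest == 0:
        return full
    return full + 1  # the final partial batch is kept


for name, n in [("train", 130612), ("valid", 11755)]:
    steps = batches_per_epoch(n, 32)
    print(name, steps, "batches,", n - steps * 32, "examples dropped")
# train 4081 batches, 20 examples dropped
# valid 367 batches, 11 examples dropped
```

The training set is reshuffled every epoch, so the 20 dropped training examples change from epoch to epoch. The 11 dropped validation examples are never shuffled, so they are always the same ones.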

Creating, Training & Tracking Our BERT Classification Model

Let's model our way to glory!!!

Create The Model

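There are two outputs from the BERT layer:
  1. pooled_output, of shape [batch_size, hidden_size] (768 for BERT-Base), which represents each input sequence as a whole.
  2. sequence_output, of shape [batch_size, max_seq_length, hidden_size], which holds a contextual representation for each input token.
For the classification task, we are only concerned with the pooled_output: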
# Building the model: input ---> BERT Layer ---> Classification Head
def create_model():
    input_word_ids = tf.keras.layers.Input(shape=(config['max_seq_length'],), dtype=tf.int32,
                                           name="input_word_ids")
    input_mask = tf.keras.layers.Input(shape=(config['max_seq_length'],), dtype=tf.int32,
                                       name="input_mask")
    input_type_ids = tf.keras.layers.Input(shape=(config['max_seq_length'],), dtype=tf.int32,
                                           name="input_type_ids")

    pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, input_type_ids])
    # for classification we only care about the pooled-output.
    # At this point we can play around with the classification head based on the
    # downstream tasks and its complexity

    drop = tf.keras.layers.Dropout(config['dropout'])(pooled_output)
    output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(drop)

    # inputs coming from the function
    model = tf.keras.Model(
        inputs={
            'input_word_ids': input_word_ids,
            'input_mask': input_mask,
            'input_type_ids': input_type_ids},
        outputs=output)

    return model
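To make the data flow concrete, here is what happens to a single example. In the original BERT, pooled_output is the final hidden state of the [CLS] token passed through a dense layer with a tanh activation. The head above then turns it into a probability with dropout and a single sigmoid unit. The toy pure-Python sketch below assumes a 3-dimensional hidden state and made-up weights. Real BERT-Base uses 768 dimensions and learned weights, and the helper names are hypothetical.

```python
import math
import random


def dense(x, weights, bias):
    """y_j = sum_i x_i * W[i][j] + b_j for a single example."""
    return [sum(xi * w[j] for xi, w in zip(x, weights)) + bias[j] for j in range(len(bias))]


def dropout(x, rate, training, rng):
    """Inverted dropout: zero units with prob `rate` and rescale the rest; identity at inference."""
    if not training:
        return x
    return [0.0 if rng.random() < rate else xi / (1.0 - rate) for xi in x]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


cls_hidden = [0.5, -1.0, 2.0]  # toy final hidden state of the [CLS] token
pooler_w = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
pooled = [math.tanh(v) for v in dense(cls_hidden, pooler_w, [0.0, 0.0, 0.0])]

# classification head: dropout (inactive at inference) -> Dense(1) -> sigmoid
dropped = dropout(pooled, rate=0.5, training=False, rng=random.Random(0))
logit = dense(dropped, [[0.2], [-0.3], [0.1]], [0.0])[0]
prob = sigmoid(logit)
print(round(prob, 3))
```

Keras applies Dropout only during training (model.fit). At evaluation and prediction time the pooled output passes through unchanged, just like the training=False call here.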

Training Your Model

# Calling the create model function to get the Keras-based functional model
model = create_model()
# using Adam with a learning rate of 2e-5, binary cross-entropy as the loss since there
# are only 2 classes, and similarly binary accuracy as a metric
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=config['learning_rate']),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.PrecisionAtRecall(0.5),
                       tf.keras.metrics.Precision(),
                       tf.keras.metrics.Recall()])
# model.summary()
Model Summary
Model Architecture Summary
One drawback of TF Hub is that the entire module is imported as a single Keras layer, so the model summary doesn't show the individual layers and parameters inside BERT.
tf.keras.utils.plot_model(model=model, show_shapes=True, dpi=76, )
The official TF Hub page states that "All parameters in the module are trainable, and fine-tuning all parameters is the recommended practice." Therefore, we'll go ahead and train the entire model without freezing anything.
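To make the loss and metrics passed to model.compile concrete, here is a pure-Python sketch of binary cross-entropy and of precision and recall on made-up predictions. It assumes a 0.5 decision threshold, which is the default for tf.keras.metrics.Precision and Recall. It clips probabilities by a small epsilon, similar to what Keras does. PrecisionAtRecall is left out: it searches for the threshold that reaches the target recall instead of using a fixed one.

```python
import math


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)], with p clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


def precision_recall(y_true, y_prob, threshold=0.5):
    """Precision and recall of the positive class after thresholding probabilities."""
    preds = [1 if p > threshold else 0 for p in y_prob]
    tp = sum(1 for y, p in zip(y_true, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(y_true, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(y_true, preds) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


y_true = [0, 0, 0, 1, 1]
y_prob = [0.1, 0.6, 0.2, 0.9, 0.4]
print(binary_cross_entropy(y_true, y_prob))
print(precision_recall(y_true, y_prob))  # (0.5, 0.5)
```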

Experiment Tracking

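Since you're here, I'm sure you have a good idea about Weights & Biases, but if not, read along :)
To start tracking the experiment, we'll create 'runs' on W&B.
wandb.init(): initializes a run with basic project information parameters: project (the project name), config (the hyperparameter dictionary, saved with the run), group (to group related runs together), and job_type (to label what each run does, such as train or evaluate):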
# Update CONFIG dict with the name of the model.
config['model_name'] = 'BERT_EN_UNCASED'
print('Training configuration: ', config)

# Initialize W&B run
run = wandb.init(project='Finetune-BERT-Text-Classification',
                 config=config,
                 group='BERT_EN_UNCASED',
                 job_type='train')
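Now, to log all the different metrics, we'll use a simple callback provided by W&B:
WandbCallback(): pass it to model.fit, and it logs the training and validation metrics reported by Keras to the W&B run after every epoch.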
Yes, it is as simple as adding a callback :D
# Train model
# setting low epochs as it starts to overfit with this limited data, please feel free to change
epochs = config['epochs']
history = model.fit(train_data,
                    validation_data=valid_data,
                    epochs=epochs,
                    verbose=1,
                    callbacks=[WandbCallback()])
run.finish()
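Since the model starts to overfit after a few epochs, it helps to check which epoch had the lowest validation loss. model.fit returns a History object whose history attribute is a dict of per-epoch lists, with keys such as 'loss' and 'val_loss'. The helper below is a hypothetical sketch that works on such a dict. It is shown here with made-up numbers rather than results from this run.

```python
def best_epoch(history_dict, monitor="val_loss"):
    """Return (1-based epoch, value) with the lowest monitored value."""
    values = history_dict[monitor]
    best_index = min(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]


# made-up values: training loss keeps falling while validation loss turns up
toy_history = {"loss": [0.15, 0.10, 0.07, 0.05, 0.03],
               "val_loss": [0.12, 0.11, 0.115, 0.13, 0.14]}
print(best_epoch(toy_history))  # (2, 0.11)
```

With the real run, you would pass history.history. Keras can also automate this during training with tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True).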

Some Training Metrics and Graphs

Let's Evaluate

Let us do an evaluation on the validation set and log the scores using W&B.
wandb.log(): Log a dictionary of scalars (metrics like accuracy and loss) and any other type of wandb object. Here we will pass the evaluation dictionary as it is and log it.
# Initialize a new run for the evaluation job
run = wandb.init(project='Finetune-BERT-Text-Classification',
                 config=config,
                 group='BERT_EN_UNCASED',
                 job_type='evaluate')

# Model evaluation on the validation set
evaluation_results = model.evaluate(valid_data, return_dict=True)

# Log scores using wandb.log()
wandb.log(evaluation_results)

# Finish the run
run.finish()

Saving the Models and Model Versioning

Finally, we're going to look at saving reproducible models with W&B. Namely, with Artifacts.

W&B Artifacts

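To save the models and make it easier to track different experiments, we'll use W&B Artifacts (wandb.Artifact), a way to save and version your datasets and models.
Within a run, there are three steps for creating and saving a model Artifact:
  1. Create an empty Artifact with wandb.Artifact(), giving it a name and a type ('model' here).
  2. Add the saved model file to it with artifact.add_file().
  3. Log it to the run with run.log_artifact().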
# Save model
model.save(f"{config['model_name']}.h5")

# Initialize a new W&B run for saving the model, changing the job_type
run = wandb.init(project='Finetune-BERT-Text-Classification',
                 config=config,
                 group='BERT_EN_UNCASED',
                 job_type='save')

# Save model as Model Artifact
artifact = wandb.Artifact(name=f"{config['model_name']}", type='model')
artifact.add_file(f"{config['model_name']}.h5")
run.log_artifact(artifact)

# Finish W&B run
run.finish()

Quick Sneak Peek into the W&B Dashboard

Things to note:

BERT Text Classification Summary & Code

I hope this hands-on tutorial was useful for you, and if you've read this far, I hope you have some good takeaways from it.
The full code of this post can be found here