How to Fine-Tune BERT for Text Classification
A code-first, reader-friendly kickstart to fine-tuning BERT for text classification with tf.data and TensorFlow Hub
Created on August 2|Last edited on November 1
What is BERT?
Bidirectional Encoder Representations from Transformers, better known as BERT, is a model introduced in a landmark paper by Google. It pushed the state-of-the-art performance on a wide range of NLP tasks and was the stepping stone for many other influential architectures.
It's not an exaggeration to say that BERT set a new direction for the entire field. It showed the clear benefits of pre-training a model on huge datasets and then transferring that knowledge to downstream tasks through fine-tuning.
In this report, we're going to look at using BERT for text classification and provide a ton of code and examples to get you up and running. If you'd like to check out the primary source yourself, here's a link to the annotated paper.

BERT Classification Model
Setting Up BERT for Text Classification
First, we'll check the installed TensorFlow version and clone the TensorFlow Model Garden:
import tensorflow as tf
print(tf.version.VERSION)

!git clone --depth 1 -b v2.4.0 https://github.com/tensorflow/models.git
The Model Garden lives in the tensorflow/models GitHub repository. A few things to note about the clone command:
- --depth 1: Git fetches only the latest snapshot of the files (a shallow clone) instead of the full history, which saves a lot of space and time.
- -b v2.4.0: checks out a specific branch or tag only (here, the v2.4.0 release).
Pick the tag that matches your TensorFlow 2.x version.
# install requirements to use tensorflow/models repository
!pip install -Uqr models/official/requirements.txt
# you may have to restart the runtime afterwards, also ignore any ERRORS popping up at this step
It's raining imports in here, friends.
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import sys
sys.path.append('models')
from official.nlp.data import classifier_data_lib
from official.nlp.bert import tokenization
from official.nlp import optimization

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set()

import wandb
from wandb.keras import WandbCallback
A quick sanity check of the installed versions and dependencies:
print("TF Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print("GPU is", "available" if tf.config.experimental.list_physical_devices("GPU") else "NOT AVAILABLE")
Let's Get the Dataset
The dataset we'll use today is provided via the Quora Insincere Questions Classification competition on Kaggle.
Feel free to download the training set (train.csv) from the Kaggle competition, or use the archive.org link in the code below to load a zipped copy directly:
Decompress and Read the Data into a pandas DataFrame:
Next, run the following:
# TO LOAD DATA FROM ARCHIVE LINK
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.read_csv('https://archive.org/download/quora_dataset_train.csv/quora_dataset_train.csv.zip',
                 compression='zip',
                 low_memory=False)
print(df.shape)
df.head(10)
# label 0 == non toxic
# label 1 == toxic
Alright. Now, let's quickly visualize that data in a W&B Table:
Let's Explore
The Label Distribution
It's a good idea to understand the data you're working with before you really dig into modeling. Here, we're going to look at our label distribution and how long our data points are, make sure our train and validation sets are well balanced, and handle a few other preliminary tasks. First, let's look at the label distribution by running:
print(df['target'].value_counts())
df['target'].value_counts().plot.bar()
plt.yscale('log');
plt.title('Distribution of Labels')

Label Distribution
Word Length and Character Length
Now, let's run a few lines of code to understand the text data we're working with here.
print('Average word length of questions in dataset is {0:.0f}.'.format(np.mean(df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in dataset is {0:.0f}.'.format(np.max(df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in dataset is {0:.0f}.'.format(np.mean(df['question_text'].apply(lambda x: len(x)))))

Preparing Training and Testing Data for Our BERT Text Classification Tasks
A few notes on our approach here:
💡
- We'll use a small portion of the data, since training on the full dataset would take ages. Feel free to include more data by changing train_size.
- Since the dataset is very imbalanced, we'll keep the same label distribution in both the train and validation sets by stratifying the split on the labels. In this section, we'll analyze the data to make sure this worked.
train_df, remaining = train_test_split(df, random_state=42, train_size=0.1, stratify=df.target.values)
valid_df, _ = train_test_split(remaining, random_state=42, train_size=0.01, stratify=remaining.target.values)
print(train_df.shape)
print(valid_df.shape)

(130612, 3)
(11755, 3)
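Under the hood, stratify= makes train_test_split sample within each class, so every split keeps the original label proportions. The stdlib-only sketch below illustrates that idea on toy labels; stratified_sample is a hypothetical helper, and scikit-learn's own implementation allocates and rounds the per-class counts with its own logic:

```python
import random
from collections import Counter


def stratified_sample(labels, fraction, seed=42):
    """Return sorted indices of a sample that keeps each label's share.

    Hypothetical helper for illustration only; not scikit-learn's code.
    """
    rng = random.Random(seed)
    by_label = {}
    for idx, label in enumerate(labels):
        by_label.setdefault(label, []).append(idx)
    picked = []
    for indices in by_label.values():
        k = round(len(indices) * fraction)  # per-class sample size
        picked.extend(rng.sample(indices, k))
    return sorted(picked)


# Toy, heavily imbalanced labels: 940 negatives and 60 positives.
labels = [0] * 940 + [1] * 60
picked = stratified_sample(labels, fraction=0.1)
counts = Counter(labels[i] for i in picked)
print(counts)  # Counter({0: 94, 1: 6}), same 94:6 ratio as the full set
```

Sampling each class separately is what guarantees the rare positive class shows up in the validation split in the same proportion, instead of depending on luck.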
Getting the Word and Character Length for the Sampled Sets
print("FOR TRAIN SET\n")
print('Average word length of questions in train set is {0:.0f}.'.format(np.mean(train_df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in train set is {0:.0f}.'.format(np.max(train_df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in train set is {0:.0f}.'.format(np.mean(train_df['question_text'].apply(lambda x: len(x)))))
print('Label Distribution in train set is \n{}.'.format(train_df['target'].value_counts()))

print("\n\nFOR VALIDATION SET\n")
print('Average word length of questions in valid set is {0:.0f}.'.format(np.mean(valid_df['question_text'].apply(lambda x: len(x.split())))))
print('Max word length of questions in valid set is {0:.0f}.'.format(np.max(valid_df['question_text'].apply(lambda x: len(x.split())))))
print('Average character length of questions in valid set is {0:.0f}.'.format(np.mean(valid_df['question_text'].apply(lambda x: len(x)))))
print('Label Distribution in validation set is \n{}.'.format(valid_df['target'].value_counts()))

The output shows that the train and validation sets are similar, both in class imbalance and in the lengths of the question texts.
Analyzing the Distribution of Question Text Length in Words
# TRAIN SET
train_df['question_text'].apply(lambda x: len(x.split())).plot(kind='hist');
plt.yscale('log');
plt.title('Distribution of question text length in words')

# VALIDATION SET
valid_df['question_text'].apply(lambda x: len(x.split())).plot(kind='hist');
plt.yscale('log');
plt.title('Distribution of question text length in words')

Analyzing the Distribution of Question Text Length in Characters
As we dig into our train and validation sets, one more thing to check is whether the question text length is roughly similar between the two. If the distributions match, the validation metrics are more likely to reflect how the model performs on data like the data it was trained on.
# TRAIN SET
train_df['question_text'].apply(lambda x: len(x)).plot(kind='hist');
plt.yscale('log');
plt.title('Distribution of question text length in characters')

# VALIDATION SET
valid_df['question_text'].apply(lambda x: len(x)).plot(kind='hist');
plt.yscale('log');
plt.title('Distribution of question text length in characters')

And it is: the distributions of question length, in both words and characters, are very similar. So far, this looks like a good train/validation split.
Taming the Data
Next, we want the dataset to be created and preprocessed on the CPU:
with tf.device('/cpu:0'):
    train_data = tf.data.Dataset.from_tensor_slices((train_df['question_text'].values, train_df['target'].values))
    valid_data = tf.data.Dataset.from_tensor_slices((valid_df['question_text'].values, valid_df['target'].values))

# lets look at 3 samples from train set
for text, label in train_data.take(3):
    print(text)
    print(label)

print(len(train_data))
print(len(valid_data))

130612
11755
Okay. Let's BERT.
Let's BERT: Get the Pre-trained BERT Model from TensorFlow Hub

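We'll use the uncased BERT-Base model (12 layers, hidden size 768, 12 attention heads) from TensorFlow Hub.
Before the text can be fed to the BERT layer, we need to tokenize it. The tokenizer's vocabulary file ships as an asset of the Hub module, and the module also tells us whether to lowercase (uncase) the input, so the tokenizer handles that for us.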
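Before any WordPiece splitting, the full BERT tokenizer runs a basic pass: for uncased models it lowercases and strips accents, and it always splits punctuation into separate tokens. Below is a simplified, stdlib-only sketch of that pass. The function basic_tokenize is hypothetical; the real BasicTokenizer in the Model Garden's tokenization module also cleans control characters, treats some ASCII symbols (such as $) as punctuation, and handles CJK characters.

```python
import unicodedata


def basic_tokenize(text, do_lower_case=True):
    """Simplified sketch of BERT's basic tokenization (hypothetical helper)."""
    if do_lower_case:
        text = text.lower()
        # Decompose accented characters, then drop the combining marks (é -> e).
        text = unicodedata.normalize("NFD", text)
        text = "".join(ch for ch in text if unicodedata.category(ch) != "Mn")
    tokens = []
    for word in text.split():
        current = ""
        for ch in word:
            if unicodedata.category(ch).startswith("P"):
                # Punctuation characters become tokens of their own.
                if current:
                    tokens.append(current)
                    current = ""
                tokens.append(ch)
            else:
                current += ch
        if current:
            tokens.append(current)
    return tokens


print(basic_tokenize("Hello World, it's a Caf\u00e9 day!"))
# ['hello', 'world', ',', 'it', "'", 's', 'a', 'cafe', 'day', '!']
```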
We'll keep all parameters in a single dictionary, so any changes can be made in one place:
# Setting some parameters
config = {'label_list': [0, 1],   # Label categories
          'max_seq_length': 128,  # maximum length of (token) input sequences
          'train_batch_size': 32,
          'learning_rate': 2e-5,
          'epochs': 5,
          'optimizer': 'adam',
          'dropout': 0.5,
          'train_samples': len(train_data),
          'valid_samples': len(valid_data),
          'train_split': 0.1,
          'valid_split': 0.01
          }
Get the BERT layer and tokenizer:
# All details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2
bert_layer = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2',
                            trainable=True)

vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()  # checks if the bert layer we are using is uncased or not
tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)
Let's check how a sample string is tokenized and mapped to token ids:
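input_string = "hello world, it is a wonderful day for learning"
# wordpiece_tokenizer skips the basic pass (punctuation splitting), so "world," becomes
# 'world' + '##,'; tokenizer.tokenize() would split ',' off as its own token instead
print(tokenizer.wordpiece_tokenizer.tokenize(input_string))
print(tokenizer.convert_tokens_to_ids(tokenizer.wordpiece_tokenizer.tokenize(input_string)))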
['hello', 'world', '##,', 'it', 'is', 'a', 'wonderful', 'day', 'for', 'learning']
[7592, 2088, 29623, 2009, 2003, 1037, 6919, 2154, 2005, 4083]
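The '##,' piece above comes from WordPiece itself. Because wordpiece_tokenizer only splits on whitespace, 'world,' reaches it as a single word, and WordPiece breaks each word into the longest vocabulary entry it can match, marking every continuation piece with '##'. A minimal sketch of that greedy longest-match-first loop, using a tiny hypothetical vocabulary (the real uncased vocabulary file has roughly 30k entries):

```python
def wordpiece(word, vocab, unk="[UNK]", max_chars=200):
    """Greedy longest-match-first WordPiece split of a single word (sketch)."""
    if len(word) > max_chars:
        return [unk]
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, shrinking until a vocab hit.
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # continuation pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # nothing matched: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, for illustration only.
toy_vocab = {"world", "##,", "learn", "##ing", "un", "##aff", "##able"}
for w in ["world,", "learning", "unaffable", "xyz"]:
    print(w, "->", wordpiece(w, toy_vocab))
```

Splitting rare words into known sub-pieces keeps the vocabulary small while still giving almost every word a non-[UNK] representation.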
Let's Get That Data Ready: Tokenize and Preprocess Text for BERT
Each row of the dataset consists of the question text and its label. Preprocessing transforms that text into BERT's input features:
- Input Word Ids: the output of our tokenizer, which converts each sentence into a sequence of token ids.
- Input Masks: since we pad every sequence to 128 tokens (the maximum sequence length), we need a mask so that the padding doesn't interfere with the actual text tokens. The mask has 1 for real tokens and 0 for padding tokens, and only real tokens are attended to.
- Segment Ids: for our text classification task there is only one sequence, so the segment_ids (input_type_ids) are just a vector of 0s.
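To make the three inputs concrete, here is a stdlib-only sketch of how already-tokenized ids become input_word_ids, input_mask and input_type_ids, in the same layout that convert_single_example produces. build_features is a hypothetical helper, and it assumes the uncased vocabulary's special-token ids ([PAD]=0, [CLS]=101, [SEP]=102). The optional second segment only shows where a segment id of 1 would come from in sentence-pair tasks:

```python
def build_features(ids_a, max_seq_length, ids_b=None, cls_id=101, sep_id=102, pad_id=0):
    """Turn token ids into (input_word_ids, input_mask, input_type_ids).

    Hypothetical helper for illustration; special-token ids assume the
    bert_en_uncased vocabulary.
    """
    ids_a = list(ids_a)
    ids_b = list(ids_b) if ids_b else []
    # [CLS] a [SEP] needs 2 special tokens; [CLS] a [SEP] b [SEP] needs 3.
    num_special = 3 if ids_b else 2
    # Truncate by trimming the longer segment one token at a time.
    while len(ids_a) + len(ids_b) > max_seq_length - num_special:
        if len(ids_a) > len(ids_b):
            ids_a.pop()
        else:
            ids_b.pop()

    input_word_ids = [cls_id] + ids_a + [sep_id]
    input_type_ids = [0] * len(input_word_ids)    # segment 0: first sentence
    if ids_b:
        input_word_ids += ids_b + [sep_id]
        input_type_ids += [1] * (len(ids_b) + 1)  # segment 1: second sentence

    input_mask = [1] * len(input_word_ids)  # 1 = real token, attended to
    padding = max_seq_length - len(input_word_ids)
    input_word_ids += [pad_id] * padding
    input_mask += [0] * padding             # 0 = padding, ignored by attention
    input_type_ids += [0] * padding
    return input_word_ids, input_mask, input_type_ids


# "hello world" -> [7592, 2088] (ids from the tokenizer output above), padded to 8.
print(build_features([7592, 2088], max_seq_length=8))
# ([101, 7592, 2088, 102, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0])
```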
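BERT was pre-trained on two tasks; the second one works on sentence pairs, which is why the model accepts segment ids at all:
- Masked language modeling: fill in randomly masked words in a sentence.
- Next sentence prediction: given two sentences, predict whether the second one actually follows the first.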
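The first pre-training task is easy to picture in code. In the BERT paper, 15% of the token positions are chosen as prediction targets; of those, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged. This stdlib-only sketch of that corruption step is purely illustrative (mask_tokens is hypothetical, and none of this is needed for fine-tuning):

```python
import random

SPECIAL_TOKENS = {"[CLS]", "[SEP]"}


def mask_tokens(tokens, vocab, mask_prob=0.15, seed=0):
    """Return (corrupted tokens, {position: original token}) for masked LM."""
    rng = random.Random(seed)
    corrupted = list(tokens)
    candidates = [i for i, tok in enumerate(tokens) if tok not in SPECIAL_TOKENS]
    num_targets = max(1, round(len(candidates) * mask_prob))
    targets = {}
    for i in rng.sample(candidates, num_targets):
        targets[i] = tokens[i]  # the model must predict the original token here
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = "[MASK]"           # probability 0.8: replace with [MASK]
        elif roll < 0.9:
            corrupted[i] = rng.choice(vocab)  # probability 0.1: random token
        # remaining probability 0.1: keep the original token
    return corrupted, targets


tokens = ["[CLS]", "the", "cat", "sat", "on", "the", "mat", "today",
          "because", "it", "was", "warm", "[SEP]"]
corrupted, targets = mask_tokens(tokens, vocab=["dog", "ran", "blue"])
print(corrupted)
print(targets)
```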
# This provides a function to convert a row into input features and a label;
# it uses classifier_data_lib, a module from the TensorFlow Model Garden we installed earlier
def create_feature(text, label, label_list=config['label_list'], max_seq_length=config['max_seq_length'], tokenizer=tokenizer):
    """
    Converts the datapoint into usable features for BERT using classifier_data_lib.

    Parameters:
      text: input text string
      label: label associated with the text
      label_list: (list) all possible labels
      max_seq_length: (int) maximum sequence length set for BERT
      tokenizer: the tokenizer object instantiated from the files in the model assets

    Returns:
      feature.input_ids: the token ids for the input text string
      feature.input_mask: the padding mask generated
      feature.segment_ids: here just a vector of 0s, since this is single-sentence classification
      feature.label_id: the corresponding label id from label_list ([0, 1] here)
    """
    # since we only have 1 sentence for classification purposes, text_b is None
    example = classifier_data_lib.InputExample(guid=None,
                                               text_a=text.numpy(),
                                               text_b=None,
                                               label=label.numpy())
    # since there is only 1 example, the index is 0
    feature = classifier_data_lib.convert_single_example(0, example, label_list,
                                                         max_seq_length, tokenizer)

    return (feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id)
- You want to use Dataset.map to apply this function to each element of the dataset. Dataset.map runs in graph mode, and graph tensors do not have a value.
- In graph mode, you can only use TensorFlow ops and functions.
So you can't .map this function directly: you need to wrap it in a tf.py_function. The tf.py_function passes regular (eager) tensors, which have a value and a .numpy() method to access it, to the wrapped Python function.
Wrapping the Python Function into a TensorFlow op for Eager Execution
def create_feature_map(text, label):
    """
    A TensorFlow function wrapper to apply the transformation on the dataset.

    Parameters:
      text: the input text string.
      label: the classification ground truth label associated with the input string

    Returns:
      A tuple of a dictionary and a corresponding label_id with it. The dictionary
      contains the input_word_ids, input_mask, input_type_ids
    """
    input_ids, input_mask, segment_ids, label_id = tf.py_function(create_feature, inp=[text, label],
                                                                  Tout=[tf.int32, tf.int32, tf.int32, tf.int32])
    max_seq_length = config['max_seq_length']

    # py_function doesn't set the shape of the returned tensors.
    input_ids.set_shape([max_seq_length])
    input_mask.set_shape([max_seq_length])
    segment_ids.set_shape([max_seq_length])
    label_id.set_shape([])

    x = {
        'input_word_ids': input_ids,
        'input_mask': input_mask,
        'input_type_ids': segment_ids
    }

    return (x, label_id)
Each element passed to the model is a tuple of a feature dictionary x and the label. The dictionary keys must match the names of the model's input layers.
Let the Data Flow: Creating the Final Input Pipeline Using tf.data

Apply the Transformation to our Train and Test Datasets
# Now we will simply apply the transformation to our train and test datasets
with tf.device('/cpu:0'):
    # train
    train_data = (train_data.map(create_feature_map,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
                  .shuffle(1000)
                  .batch(32, drop_remainder=True)
                  .prefetch(tf.data.experimental.AUTOTUNE))

    # valid
    valid_data = (valid_data.map(create_feature_map,
                                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
                  .batch(32, drop_remainder=True)
                  .prefetch(tf.data.experimental.AUTOTUNE))
The resulting tf.data.Datasets return (features, labels) pairs, as expected by keras.Model.fit.
# train data spec: we can finally see that each input datapoint is now converted
# into the BERT-specific input tensors
train_data.element_spec
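Two details of the pipeline above are worth a quick check. With batch(32, drop_remainder=True), examples that don't fill a final batch are skipped each epoch, and shuffle(1000) only mixes elements within a 1,000-element buffer. The stdlib sketch below computes the batch arithmetic for the split sizes printed earlier and imitates a fixed-size shuffle buffer; buffer_shuffle is a hypothetical stand-in, not tf.data's implementation:

```python
import random


def batch_stats(num_examples, batch_size):
    """Return (full batches, examples dropped) when drop_remainder=True."""
    return divmod(num_examples, batch_size)


print(batch_stats(130612, 32))  # (4081, 20): train steps per epoch, examples skipped
print(batch_stats(11755, 32))   # (367, 11): validation steps, examples skipped


def buffer_shuffle(items, buffer_size, seed=0):
    """Yield items in the order a fixed-size shuffle buffer would emit them."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a random element from the full buffer; the next item refills it.
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:  # drain whatever is left at the end of the input
        yield buffer.pop(rng.randrange(len(buffer)))


print(list(buffer_shuffle(range(10), buffer_size=3)))
```

Because train_test_split already shuffles the rows by default, the local mixing from a 1,000-element buffer is usually enough here.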

Creating, Training & Tracking Our BERT Classification Model.
Let's model our way to glory!!!
Create The Model
There are two outputs from the BERT Layer:
- A pooled_output of shape [batch_size, 768] with representations for the entire input sequences.
- A sequence_output of shape [batch_size, max_seq_length, 768] with representations for each input token (in context).
For the classification task, we are only concerned with the pooled_output:
# Building the model, input ---> BERT Layer ---> Classification Head
def create_model():
    input_word_ids = tf.keras.layers.Input(shape=(config['max_seq_length'],), dtype=tf.int32,
                                           name="input_word_ids")
    input_mask = tf.keras.layers.Input(shape=(config['max_seq_length'],), dtype=tf.int32,
                                       name="input_mask")
    input_type_ids = tf.keras.layers.Input(shape=(config['max_seq_length'],), dtype=tf.int32,
                                           name="input_type_ids")

    pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, input_type_ids])
    # for classification we only care about the pooled-output.
    # At this point we can play around with the classification head based on the
    # downstream tasks and its complexity
    drop = tf.keras.layers.Dropout(config['dropout'])(pooled_output)
    output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(drop)

    # inputs coming from the function
    model = tf.keras.Model(
        inputs={
            'input_word_ids': input_word_ids,
            'input_mask': input_mask,
            'input_type_ids': input_type_ids
        },
        outputs=output)

    return model
Training Your Model
# Calling the create model function to get the keras based functional model
model = create_model()
# using adam with a lr of 2*(10^-5), loss as binary cross entropy as only
# 2 classes and similarly binary accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=config['learning_rate']),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.PrecisionAtRecall(0.5),
                       tf.keras.metrics.Precision(),
                       tf.keras.metrics.Recall()])
#model.summary()
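For reference, the classification head and loss reduce to two formulas: the Dense(1) layer with a sigmoid computes p = sigmoid(w·h + b) on the pooled vector h, and binary cross-entropy averages -[y·log(p) + (1 - y)·log(1 - p)] over the batch. A stdlib-only sketch of both; head_probability and binary_cross_entropy are hypothetical helpers, and the epsilon clipping is similar in spirit to what Keras does to avoid log(0), but this is not Keras' exact code:

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def head_probability(pooled, weights, bias):
    """Dense(1, activation='sigmoid') applied to one pooled vector."""
    z = sum(h * w for h, w in zip(pooled, weights)) + bias
    return sigmoid(z)


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean BCE over a batch of 0/1 labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


print(head_probability([0.0] * 768, [1.0] * 768, 0.0))  # 0.5
print(binary_cross_entropy([1, 0], [0.5, 0.5]))         # ln(2), about 0.6931
```

A single sigmoid unit is enough for two classes because p(insincere) = p and p(sincere) = 1 - p.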
Model Summary

Model Architecture Summary
One drawback of TF Hub is that the entire module is imported as a single Keras layer, so the model summary doesn't show the individual layers and parameters inside BERT.
tf.keras.utils.plot_model(model=model, show_shapes=True, dpi=76, )

The official TF Hub page states that "All parameters in the module are trainable, and fine-tuning all parameters is the recommended practice." Therefore, we will train the entire model without freezing anything.
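Since every parameter is trained and the summary hides the module's internals, a back-of-the-envelope count shows roughly how many weights fine-tuning updates. The sketch assumes the standard BERT-Base configuration: 12 layers, hidden size 768 and 12 heads (all from the module name L-12_H-768_A-12), plus a 30,522-token vocabulary, 512 positions, 2 token types and a feed-forward size of 3,072, which are assumptions about the checkpoint rather than values read from TF Hub. The number of heads does not change the count, because the heads split the hidden size between them:

```python
def bert_param_count(vocab=30522, hidden=768, layers=12, ffn=3072, max_pos=512, type_vocab=2):
    """Approximate BERT encoder + pooler parameter count (assumed BERT-Base sizes)."""
    def dense(n_in, n_out):
        return n_in * n_out + n_out  # weight matrix + bias

    layer_norm = 2 * hidden  # gamma + beta
    embeddings = (vocab + max_pos + type_vocab) * hidden + layer_norm
    attention = 3 * dense(hidden, hidden) + dense(hidden, hidden) + layer_norm  # Q, K, V, output
    feed_forward = dense(hidden, ffn) + dense(ffn, hidden) + layer_norm
    pooler = dense(hidden, hidden)
    return embeddings + layers * (attention + feed_forward) + pooler


encoder = bert_param_count()
head = 768 * 1 + 1  # Dense(1) on the pooled output; Dropout has no weights
print(f"{encoder:,} encoder params + {head} head params")
# 109,482,240 encoder params + 769 head params
```

The exact figure Keras reports for the Hub layer may differ slightly, but it should be on the order of 110M, which is why fine-tuning uses such a small learning rate and few epochs.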
Experiment Tracking
Since you are here, I am sure you have a good idea about Weights and Biases but if not, then read along :)
To start experiment tracking, we will create 'runs' on W&B.
wandb.init() initializes a run with some basic project information, passed as parameters:
- project: The project name, will create a new project tab where all the experiments for this project will be tracked
- config: A dictionary of all parameters and hyper-parameters we wish to track
- group: optional, but would help us to group by different parameters later on
- job_type: describes the type of job, which helps when grouping experiments later, e.g. "train", "evaluate", etc.
# Update CONFIG dict with the name of the model.
config['model_name'] = 'BERT_EN_UNCASED'
print('Training configuration: ', config)

# Initialize W&B run
run = wandb.init(project='Finetune-BERT-Text-Classification',
                 config=config,
                 group='BERT_EN_UNCASED',
                 job_type='train')
Now, to log all the different metrics, we will use a simple callback provided by W&B.
WandbCallback(): https://docs.wandb.ai/guides/integrations/keras
Yes, it is as simple as adding a callback :D
# Train model
# setting low epochs as it starts to overfit with this limited data, please feel free to change
epochs = config['epochs']
history = model.fit(train_data,
                    validation_data=valid_data,
                    epochs=epochs,
                    verbose=1,
                    callbacks=[WandbCallback()])
run.finish()

Some Training Metrics and Graphs
Let's Evaluate
Let us do an evaluation on the validation set and log the scores using W&B.
wandb.log(): Log a dictionary of scalars (metrics like accuracy and loss) and any other type of wandb object. Here we will pass the evaluation dictionary as it is and log it.
# Initialize a new run for the evaluation-job
run = wandb.init(project='Finetune-BERT-Text-Classification',
                 config=config,
                 group='BERT_EN_UNCASED',
                 job_type='evaluate')

# Model Evaluation on validation set
evaluation_results = model.evaluate(valid_data, return_dict=True)

# Log scores using wandb.log()
wandb.log(evaluation_results)

# Finish the run
run.finish()
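On data this imbalanced, binary accuracy alone can look deceptively good, which is why the model also tracks precision and recall. Here is a stdlib-only sketch of those three metrics at a 0.5 threshold (the default for Keras' BinaryAccuracy, Precision and Recall) on toy predictions; binary_metrics is a hypothetical helper, and the "always sincere" baseline shows how a useless model can still score high accuracy:

```python
def binary_metrics(y_true, y_prob, threshold=0.5):
    """Accuracy, precision and recall for 0/1 labels and predicted probabilities."""
    tp = fp = fn = tn = 0
    for y, p in zip(y_true, y_prob):
        pred = 1 if p > threshold else 0
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1 and y == 0:
            fp += 1
        elif pred == 0 and y == 1:
            fn += 1
        else:
            tn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = (tp + tn) / len(y_true)
    return {"accuracy": accuracy, "precision": precision, "recall": recall}


# Toy labels: 95 sincere (0) and 5 insincere (1) questions.
y_true = [0] * 95 + [1] * 5
always_sincere = [0.1] * 100
# 2 false alarms among the sincere questions, 4 of the 5 insincere ones caught.
better = [0.9] * 2 + [0.1] * 93 + [0.8] * 4 + [0.2]

print(binary_metrics(y_true, always_sincere))
# {'accuracy': 0.95, 'precision': 0.0, 'recall': 0.0}
print(binary_metrics(y_true, better))
```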
Saving the Models and Model Versioning
Finally, we're going to look at saving reproducible models with W&B. Namely, with Artifacts.
W&B Artifacts
To save the models and make it easier to track different experiments, we'll use W&B Artifacts, which are a way to save and version your datasets and models.
Within a run, there are three steps for creating and saving a model Artifact.
- Create an empty Artifact with wandb.Artifact().
- Add your model file to the Artifact with artifact.add_file().
- Call run.log_artifact() to save the Artifact.
# Save model
model.save(f"{config['model_name']}.h5")

# Initialize a new W&B run for saving the model, changing the job_type
run = wandb.init(project='Finetune-BERT-Text-Classification',
                 config=config,
                 group='BERT_EN_UNCASED',
                 job_type='save')

# Save model as Model Artifact
artifact = wandb.Artifact(name=f"{config['model_name']}", type='model')
artifact.add_file(f"{config['model_name']}.h5")
run.log_artifact(artifact)

# Finish W&B run
run.finish()
Quick Sneak Peek into the W&B Dashboard
Things to note:
- Grouping of experiments and runs.
- Visualizations of all training logs and metrics.
- Visualizations for system metrics could be useful when training on cloud instances or physical GPU machines
- Hyperparameter tracking in the tabular form.
- Artifacts: Model versioning and storage.

BERT Text Classification Summary & Code
I hope this hands-on tutorial was useful, and if you've read this far, I hope you've picked up some good takeaways.
Tags: Intermediate, NLP, Classification, Tutorial, BERT, Panels, Plots, Sweeps, Tables, Kaggle, Exemplary, Large Models, LLM, Fine-tuning