Kaggle Starter Kernel: Jigsaw Multilingual Toxic Comment Classification
This article compares three models trained to compete in Kaggle's Jigsaw Multilingual Toxic Comment Classification competition.
Competition introduction
In this article, I am going to present a comparison between three models trained to compete on Kaggle's Jigsaw Multilingual Toxic Comment Classification:
- Vanilla neural network as a classification head on top of DistilBERT, without class weights
- Vanilla neural network as a classification head on top of DistilBERT, with class weights
- CNN-based (1D convolutions) neural network as a classification head on top of DistilBERT, with class weights
Table of Contents
- Competition introduction
- Table of Contents
- A Note on the DistilBERT Model
- Vanilla Neural Network as the Classification Top
- Vanilla neural network with class weights
- CNN-Based Classification Top With Class Weights
- Battle Between the Three
Our goal in this competition is to predict the probability that a comment is toxic. At its core, this challenge is a binary text classification problem. The dataset is multilingual, which makes it a bit more challenging than a typical text classification problem.
Our dataset looks like so:

The labels useful for the competition are present in the toxic column, where 0 indicates a benign, non-toxic comment and 1 indicates a toxic comment. As you would expect, the dataset is highly imbalanced: the ratio of toxic to non-toxic comments is heavily skewed. You can see an amazing EDA on the dataset here.
A Note on the DistilBERT Model
The DistilBERT model comes from the good folks at Hugging Face via their mighty transformers library. It makes it extremely easy to plug SoTA NLP models into our applications, and it also provides the utility functions needed to prepare your text data in the format the corresponding NLP model expects. We will mainly be using the following two components from transformers:
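Most likely, these are the tokenizer and the TF model class. Here's a minimal sketch; the distilbert-base-multilingual-cased checkpoint name is my assumption, not something stated above:

```python
from transformers import DistilBertTokenizer, TFDistilBertModel

MODEL_NAME = "distilbert-base-multilingual-cased"  # assumed checkpoint

# 1. The tokenizer turns raw comments into input IDs and attention masks
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
encoded = tokenizer(
    ["This is a sample comment"],
    max_length=500,        # the models below use 500-d input sequences
    padding="max_length",
    truncation=True,
    return_tensors="tf",
)

# 2. The pre-trained DistilBERT body that the classification heads sit on top of
distilbert = TFDistilBertModel.from_pretrained(MODEL_NAME)
outputs = distilbert(encoded["input_ids"], attention_mask=encoded["attention_mask"])
print(outputs[0].shape)  # (1, 500, 768) last hidden states
```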
Vanilla Neural Network as the Classification Top
A schematic diagram of the model is as follows:

The input sequences and their corresponding attention masks are each 500-d vectors. They pass through the DistilBERT model, where the pre-trained weights are utilized, and the resulting representation is then fed through a fully-connected network. This model was not trained using class weights.
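For reference, here is a minimal sketch of such a model in tf.keras. The hidden layer width, dropout rate, and learning rate are my assumptions; only the 500-d inputs and the DistilBERT backbone come from the diagram above:

```python
import tensorflow as tf
from transformers import TFDistilBertModel

MAX_LEN = 500  # the article uses 500-d input sequences and masks

def build_vanilla_model(model_name="distilbert-base-multilingual-cased"):
    input_ids = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
    attention_mask = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")

    # Pre-trained DistilBERT backbone
    distilbert = TFDistilBertModel.from_pretrained(model_name)
    sequence_output = distilbert(input_ids, attention_mask=attention_mask)[0]  # (batch, 500, 768)

    # Use the first ([CLS]) token as a pooled representation for the classifier
    cls_token = sequence_output[:, 0, :]
    x = tf.keras.layers.Dense(256, activation="relu")(cls_token)
    x = tf.keras.layers.Dropout(0.3)(x)
    output = tf.keras.layers.Dense(1, activation="sigmoid")(x)

    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return model
```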
Let's see how it performs: trained for only two epochs, the model already yields **84.85%** accuracy on the validation dataset. Since this dataset suffers from class imbalance, we should also consider the precision and recall for the positive class (toxic comments), but for now we can skip them.
It takes 1,090 seconds to train this model on a TPU v3-8 (thanks to Kaggle for making TPUs available). Since I am logging some demo predictions along the way, this training time should not be used for benchmarking.
To see how the model does as it is getting trained, I implemented a simple callback (sketched after this list) that would:
- Take sample indices to grab test data points.
- Preprocess those data points.
- Generate predictions and process them.
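A minimal sketch of such a callback is below. The test_texts, sample_indices, and preprocess arguments are placeholders for whatever the actual kernel uses; the callback logs the predictions to W&B as a table at the end of every epoch:

```python
import wandb
import tensorflow as tf

class PredictionLogger(tf.keras.callbacks.Callback):
    """Log predictions on a few held-out comments after every epoch (sketch)."""

    def __init__(self, test_texts, sample_indices, preprocess):
        super().__init__()
        self.texts = [test_texts[i] for i in sample_indices]
        # `preprocess` is assumed to tokenize/pad the texts into model-ready inputs
        self.inputs = preprocess(self.texts)

    def on_epoch_end(self, epoch, logs=None):
        probs = self.model.predict(self.inputs).ravel()
        table = wandb.Table(columns=["comment", "toxicity_probability"])
        for text, prob in zip(self.texts, probs):
            table.add_data(text, float(prob))
        wandb.log({"demo_predictions": table}, commit=False)
```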
The comments are non-English, and they are impossible for me to analyze without Google Translate. The next iteration of the network will include class weights during training, and when logging its predictions in the next section we will also log their English translations (with the help of the googletrans library).
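The translation step itself can be as simple as the snippet below, assuming the synchronous googletrans API (comment_text is just a placeholder):

```python
from googletrans import Translator

translator = Translator()

# Translate a non-English comment to English before logging it
comment_text = "Ceci est un commentaire."  # placeholder
english_text = translator.translate(comment_text, dest="en").text
```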
Vanilla neural network with class weights
Scikit-Learn provides a utility function to compute the class weights, and here's how I used it:

    import numpy as np
    from sklearn.utils import class_weight

    # Account for the class imbalance
    # (keyword arguments work across scikit-learn versions)
    class_weights = class_weight.compute_class_weight(
        class_weight="balanced",
        classes=np.unique(y_train),
        y=y_train,
    )
If you print out class_weights, you would get array([0.55288749, 5.22701553]). This means that the model would treat a toxic comment as roughly 9.45x (5.22701553 / 0.55288749) as important as a non-toxic comment. This way, errors on the under-represented class (the toxic comments) and the over-represented class (the non-toxic comments) are penalized on a more equal footing during training.
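During training, Keras expects these weights as a {class_index: weight} dictionary passed to model.fit; the data arguments and epoch count below are placeholders:

```python
# Keras expects class weights as a {class_index: weight} dictionary
class_weight_dict = dict(enumerate(class_weights))  # {0: ~0.55, 1: ~5.23}

model.fit(
    train_inputs, y_train,                      # placeholders for the tokenized training data
    validation_data=(valid_inputs, y_valid),
    epochs=2,
    class_weight=class_weight_dict,
)
```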
Let's see if it changes anything.
CNN-Based Classification Top With Class Weights
We will be using a CNN-based architecture as the classification head, which looks like this:

Here we use 1D convolutions to exploit the local patterns present in the comments. This model also overfits less, likely because the convolutional head is better suited to capture those local patterns and, in turn, pick up the discriminative features that mark a comment as toxic or non-toxic.
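A minimal sketch of what such a CNN-based head could look like is below; the filter count, kernel size, and dropout rate are assumptions, and only the Conv1D-on-DistilBERT idea comes from the diagram above:

```python
import tensorflow as tf
from transformers import TFDistilBertModel

MAX_LEN = 500

def build_cnn_model(model_name="distilbert-base-multilingual-cased"):
    input_ids = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
    attention_mask = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")

    distilbert = TFDistilBertModel.from_pretrained(model_name)
    sequence_output = distilbert(input_ids, attention_mask=attention_mask)[0]  # (batch, 500, 768)

    # 1D convolutions over the token dimension to pick up local toxicity cues
    x = tf.keras.layers.Conv1D(64, kernel_size=3, activation="relu")(sequence_output)
    x = tf.keras.layers.GlobalMaxPooling1D()(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    output = tf.keras.layers.Dense(1, activation="sigmoid")(x)

    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return model
```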
Battle Between the Three
All these results suggest that the model with the CNN-based classification top with class weights (the orange one) performs better than the other two. As mentioned above, the CNN's ability to pick out the local patterns in the comments that contribute to their toxicity or non-toxicity makes it a good candidate here. Here are some additional hacks you might want to incorporate to improve performance further (a combined sketch of both follows the list):
- Add a learning rate schedule where the learning rate starts very low so that the pre-trained DistilBERT weights are not disrupted early on, and then oscillates for faster convergence. Check out this article for more on this.
- As this is an imbalanced dataset, we should also set the initial bias of the output layer correctly, as mentioned in this tutorial.
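Here is a combined sketch of both ideas. The warm-up-only schedule below is a simplification of the oscillating schedule linked above, and the initial bias follows the log(pos/neg) recipe from that tutorial; all of the constants are assumptions:

```python
import numpy as np
import tensorflow as tf

# 1. Learning rate warm-up: start very low so the pre-trained DistilBERT weights
#    are not disrupted early on, then settle at the base learning rate.
def lr_schedule(epoch, lr, warmup_epochs=1, start_lr=1e-7, base_lr=2e-5):
    if epoch < warmup_epochs:
        return start_lr + (base_lr - start_lr) * (epoch + 1) / warmup_epochs
    return base_lr

lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

# 2. Initial output bias for an imbalanced dataset: b0 = log(pos / neg)
pos = np.sum(y_train == 1)
neg = np.sum(y_train == 0)
initial_bias = np.log(pos / neg)

output_layer = tf.keras.layers.Dense(
    1,
    activation="sigmoid",
    bias_initializer=tf.keras.initializers.Constant(initial_bias),
)
```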