Jigsaw Multilingual Toxic Comment Classification
Starter code for the "Jigsaw Multilingual Toxic Comment Classification" Kaggle Competition (https://www.kaggle.com/c/jigsaw-multilingual-toxic-comment-classification).
Competition introduction
[Embedded W&B panel: run set (3 runs)]
A note on the DistilBERT model
The DistilBERT model comes from the good folks at Hugging Face via their mighty library, transformers. It makes it extremely easy to plug SoTA NLP models into our applications. They also provide the utility functions needed to prepare your text data so it conforms to what the corresponding model expects. We will mainly be using the following two components from transformers:
transformers.DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')
transformers.TFDistilBertModel.from_pretrained('distilbert-base-multilingual-cased')
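To make the data-preparation step concrete, here is a minimal sketch of how the tokenizer can turn raw comments into the fixed-length input IDs and attention masks the model expects. The max_length of 500 matches the 500-d input vectors described below; the comment texts are placeholders, and the batched tokenizer call requires transformers v3+ (older versions expose the same options via batch_encode_plus).

import transformers

tokenizer = transformers.DistilBertTokenizer.from_pretrained(
    'distilbert-base-multilingual-cased')

# Placeholder comments purely for illustration
comments = ["This is a perfectly fine comment.",
            "Another example comment."]

# Pad/truncate every comment to 500 tokens and build the attention masks
encoded = tokenizer(comments,
                    padding='max_length',
                    truncation=True,
                    max_length=500,
                    return_tensors='tf')

input_ids = encoded['input_ids']            # shape: (2, 500)
attention_mask = encoded['attention_mask']  # shape: (2, 500)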
Vanilla neural network as the classification top
A schematic diagram of the model is as follows -
The input sequences and their corresponding attention masks are both 500-d vectors. They pass through the DistilBERT model, where the pre-trained weights are utilized, and the resulting representation then propagates through a fully-connected network. This model was not trained with class weights. Let's see how it performs.
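Since the report does not include the model-building code here, the following is a minimal Keras sketch of such an architecture. The hidden-layer width, dropout rate, and learning rate are my assumptions for illustration, not necessarily the values used in the logged runs.

import tensorflow as tf
import transformers

def build_vanilla_model(max_len=500):
    # Two 500-d inputs: token IDs and the corresponding attention mask
    input_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
    attention_mask = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="attention_mask")

    distilbert = transformers.TFDistilBertModel.from_pretrained(
        'distilbert-base-multilingual-cased')
    # Last hidden state of DistilBERT: (batch_size, max_len, 768)
    sequence_output = distilbert(input_ids, attention_mask=attention_mask)[0]
    # Use the first ([CLS]) token's hidden state as a pooled representation
    cls_token = sequence_output[:, 0, :]

    # Fully-connected classification top (sizes are illustrative)
    x = tf.keras.layers.Dense(256, activation='relu')(cls_token)
    x = tf.keras.layers.Dropout(0.3)(x)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model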
[Embedded W&B panel: run set (1 run)]
Vanilla neural network with class weights
Scikit-Learn provides a utility function to compute the class weights, and here's how I used it -
# Account for the class imbalance
import numpy as np
from sklearn.utils import class_weight

class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=np.unique(y_train),
                                                  y=y_train)
If you print out class_weights, you get array([0.55288749, 5.22701553]). This means the model treats a toxic comment as roughly 9.45x (5.22701553 / 0.55288749) as important as a non-toxic one. The weighting penalizes mistakes on the under-represented class (here, the toxic comments) more heavily, so the over-represented class (the non-toxic comments) no longer dominates the training loss.
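Keras accepts these weights directly through the class_weight argument of model.fit. Here is a minimal sketch, assuming the model defined above; X_train_ids / X_train_masks and the hyperparameters are hypothetical placeholders.

# Map the computed weights to their class labels (0: non-toxic, 1: toxic)
class_weight_dict = {0: class_weights[0], 1: class_weights[1]}

# X_train_ids / X_train_masks stand in for the tokenizer outputs
model.fit([X_train_ids, X_train_masks], y_train,
          epochs=3,
          batch_size=32,
          class_weight=class_weight_dict)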
Let's see if it changes anything.
[Embedded W&B panel: run set (1 run)]
CNN-based classification top with class weights
We will be using a CNN-based architecture as the classification head, which looks like so -
Here, we use 1D convolutions to exploit the locality of the patterns present in the comments. This model also overfits less, likely because the convolutional filters are better suited to picking out the local, discriminative features that signal whether a comment is toxic.
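As before, the report does not show the code for this head, so here is a minimal sketch (reusing the imports from the earlier sketch) under my assumptions; the filter count, kernel size, and dense-layer width are illustrative.

def build_cnn_model(max_len=500):
    input_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
    attention_mask = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="attention_mask")

    distilbert = transformers.TFDistilBertModel.from_pretrained(
        'distilbert-base-multilingual-cased')
    # Full token-level representation: (batch_size, max_len, 768)
    sequence_output = distilbert(input_ids, attention_mask=attention_mask)[0]

    # 1D convolutions over the token dimension to pick up local patterns
    x = tf.keras.layers.Conv1D(filters=128, kernel_size=3, activation='relu')(sequence_output)
    x = tf.keras.layers.GlobalMaxPooling1D()(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model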
[Embedded W&B panel: run set (1 run)]
Battle between the three
[Embedded W&B panel: run set (3 runs)]