
How to Modify the Loss Function for a Class-Imbalanced Binary Classifier in TensorFlow

In this report, we will learn how to modify the loss function for a class-imbalanced binary classifier.

Problem

I am trying to apply deep learning to a binary classification problem with high class imbalance between the target classes (500k vs. 31k examples). I want to write a custom loss function that looks something like:

minimize(100-((predicted_smallerclass)/(total_smallerclass))*100)

Cross Entropy - The default loss function

The pseudocode above effectively asks the model to maximize recall on the smaller class. When we encounter a high degree of class imbalance, as in the example above (class A has 500k examples, whereas class B has only 31k), we by default use a softmax or sigmoid activation function in the output layer and cross entropy as the loss for classification tasks.

Regular cross entropy loss is defined like this:

loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
               = -x[class] + log(\sum_j exp(x[j]))
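We can sanity-check this formula against TensorFlow's built-in op. Here is a minimal sketch with a single made-up example (the logits and one-hot label are arbitrary, chosen only for illustration):

import tensorflow as tf

# One example with 2 classes: logits x, true class 0.
x = tf.constant([[2.0, 0.5]])       # shape [1, 2]
labels = tf.constant([[1.0, 0.0]])  # one-hot, true class = 0

# Built-in cross entropy.
xent = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=x)

# Manual computation: -x[class] + log(\sum_j exp(x[j]))
manual = -x[0, 0] + tf.math.reduce_logsumexp(x[0])

print(float(xent[0]), float(manual))  # both ≈ 0.2014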

Cross entropy does not naturally counter the class imbalance problem. However, with some minor modifications, we can still build a really good classifier.

Modified Loss Function

We can modify this cross entropy loss function for class imbalance by incorporating class weights, which assign a different importance to each class.

Thus the modified loss function would be:

loss(x, class) = weights[class] * (-x[class] + log(\sum_j exp(x[j])))

This is the proposed code:

import tensorflow as tf

# Compute class weights for your binary classification problem.
ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([[ratio, 1.0 - ratio]])  # shape [1, 2]

# `logits` is the output of a dense (fully-connected) layer,
# and `labels` holds the corresponding one-hot targets.
logits = ...  # shape [batch_size, 2]
labels = ...  # shape [batch_size, 2], one-hot

# The weight for each datapoint, depending on its label.
weight_per_label = tf.transpose(
    tf.matmul(labels, tf.transpose(class_weight)))  # shape [1, batch_size]

xent = tf.multiply(
    weight_per_label,
    tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xent_raw"))  # shape [1, batch_size]
loss = tf.reduce_mean(xent)  # scalar

Note: This answer was inspired by the answers in this Stack Overflow thread.
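For the binary (sigmoid) case, TensorFlow also provides a built-in weighted loss, tf.nn.weighted_cross_entropy_with_logits, which scales the loss term for positive examples by a factor pos_weight. A minimal sketch, with made-up labels and logits, and pos_weight derived from the class counts in the question:

import tensorflow as tf

# Upweight the minority positive class by the imbalance ratio (~16x).
pos_weight = 500.0 / 31.0

labels = tf.constant([1.0, 0.0, 1.0])   # made-up binary targets
logits = tf.constant([0.3, -1.2, 2.0])  # made-up raw scores from a dense layer

# Like sigmoid cross entropy, but positive examples count pos_weight times more.
xent = tf.nn.weighted_cross_entropy_with_logits(
    labels=labels, logits=logits, pos_weight=pos_weight)
loss = tf.reduce_mean(xent)
print(float(loss))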

Effect of class weights

In this toy experiment, we will demonstrate the effect of class weights in image classification for the data imbalance problem. For more techniques for dealing with class imbalance, check out "Simple Ways to Tackle Class Imbalance" by Aritra Roy Gosthipaty and me.
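If you train with tf.keras, the same effect can be achieved by passing per-class weights to model.fit. Below is a minimal sketch with a toy model and randomly generated placeholder data; the architecture, input shape, and weighting scheme (inverse class frequency) are illustrative assumptions, not the exact setup of the experiment:

import numpy as np
import tensorflow as tf

# Placeholder data: ~6% positives, mimicking heavy imbalance.
x_train = np.random.randn(1000, 20).astype("float32")
y_train = (np.random.rand(1000) < 0.06).astype("float32")

# Inverse-frequency class weights, normalized to average roughly 1.
n_pos = y_train.sum()
n_neg = len(y_train) - n_pos
class_weight = {0: len(y_train) / (2.0 * n_neg),
                1: len(y_train) / (2.0 * n_pos)}

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(20,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Keras multiplies each sample's loss by the weight of its class.
model.fit(x_train, y_train, class_weight=class_weight, epochs=2, verbose=0)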

[Image: Preview from the report]

We have created a CIFAR-2 dataset, the details of which can be found in the linked report. The accuracy and loss plots shown below clearly show that classification improves when class weights are used along with the cross entropy loss; the class weights also act as a regularizer.




[Run set: 2 runs]