In this report, we explore various methods used to counter class imbalance in image classification problems. To make the study more intuitive, we restrict ourselves to binary classification. Instead of using a standard dataset with inherent class imbalance, we built a “synthetic” (not to be confused with GAN-generated) dataset out of CIFAR-10 with two classes. We chose car and plane as the two classes. There is no particular reason for the choice, other than us being lazy :sleeping: and those two turning out to be the first two classes in the dataset.
In the CIFAR-10 dataset, each class consists of 5000 samples in the training set. We will call our two-class dataset the CIFAR-2 dataset for obvious reasons. CIFAR-2 needs to have a stark data imbalance.
We opted for the following data distribution – plane: 5000 samples (majority) and car: 50 samples (minority).
Here’s a quick overview of the methods that we have used to counter class imbalance: class weights, random oversampling, random undersampling, and a two-phase training procedure borrowed from a brain tumor segmentation paper.
First, we need to address the problem. With an imbalance in the dataset, the model does not learn much about the minority class. Since the model sees fewer data points from the minority class, it does not learn a good representation for it. This results in a classifier that performs poorly on the minority class.
So let’s draw an analogy here and understand what is up. A neural network has weights and biases, which can be thought of as knobs on a radio. We turn the knob to tune the frequency of our radio, turning it both ways until we find the perfect spot. In a neural network, the weights and biases are tuned until the sweet spot is found. The sweet spot for the knobs depends on what we want from the network. For a classification task, this sweet spot is the one where the network builds a function that maps input data to its proper class. Now that the foundation is laid, what would happen if the model (the neural network) sees one of the classes a lot more than the other? The knobs would be tuned such that the predictions lean towards the majority class.
To compare the different methods, we observe the following metrics:
Remember, in a classification problem, the model, when fed an input, maps it to one of the classes. The prediction is true if the input indeed belongs to the predicted class and false otherwise. Now let us briefly go through each metric one by one.
True and False Positives/Negatives – A true positive (or negative) is a sample correctly predicted as the positive (or negative) class, while a false positive (or negative) is a sample incorrectly predicted as that class.
Accuracy – This is the ratio of the number of correctly predicted samples to the total number of samples.
Confusion Matrix
The table above is a confusion matrix. This is a great tool to summarize the performance of a classification task. With a model ready, we decide on a (probability) threshold and then predict the classes. This matrix is a snapshot of how well the model fared. From the numbers in the cells, one can quickly figure out precision and recall. Confusion matrices are often drawn as heat maps: with colors, the user can visualize the performance of a model at a glance. For a good model, the diagonal of the matrix must light up. Note that the confusion matrix of a perfect model is a diagonal matrix.
Precision – This signifies the percentage of predicted positives that were correctly classified.
Recall – This signifies the percentage of actual positives that were correctly classified.
AUC-ROC – A classifier outputs probabilities for the different classes. There needs to be a threshold to convert the probability scores to discrete class labels. All of the above metrics require this threshold to be defined before they can be evaluated. Pondering on the topic a little further, one realizes that the threshold is a hyperparameter. We would need to experiment with many thresholds and find the best one, which can be time-consuming and computationally expensive. This is where a receiver operating characteristic (ROC) curve comes in handy. The ROC plot is drawn with True Positive Rate (TPR) vs. False Positive Rate (FPR). We sample multiple thresholds in the normalized range [0, 1], and for each one of them, the TPR and FPR are computed. The graph is a tremendous visual ally for quickly determining the best threshold.
AUC stands for “Area Under the ROC Curve” and provides a quantifiable measure of how robust the classifier is. This metric evaluates the area under the ROC curve. With an ideal ROC, the curve hugs the upper left-hand corner, signifying that the classifier has a threshold that differentiates between the two classes very well. With such a corner-hugging curve, the AUC evaluates to nearly 1. On the flip side, if the model were no better than a random guess, TPR and FPR would increase simultaneously parallel to one another, corresponding to an AUC of 0.5.
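To make these definitions concrete, here is a minimal sketch of how the above metrics can be computed with scikit-learn. The variable names and the toy labels/probabilities are ours, and the 0.5 threshold is just an example.

import numpy as np
from sklearn.metrics import (accuracy_score, confusion_matrix,
                             precision_score, recall_score, roc_auc_score)

# y_true: ground-truth labels (0 = negative, 1 = positive)
# y_prob: predicted probability of the positive class
y_true = np.array([0, 0, 1, 1, 1, 0])
y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2])

# Threshold the probabilities to get discrete class labels
y_pred = (y_prob >= 0.5).astype(int)

print("Accuracy :", accuracy_score(y_true, y_pred))
print("Precision:", precision_score(y_true, y_pred))
print("Recall   :", recall_score(y_true, y_pred))
print("Confusion matrix:\n", confusion_matrix(y_true, y_pred))

# AUC is threshold-free: it works on the raw probabilities
print("AUC      :", roc_auc_score(y_true, y_prob))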
As stated above, we have built a “synthetic” dataset out of the CIFAR-10 dataset. The aim is to have a common dataset to model our experiments on. There are many ways to counter class imbalance and get results tailored to the problem statement at hand. With our approach, we avoid biasing ourselves toward some hacky, domain-specific trick just to get good results; we deliberately did not incorporate domain knowledge (in some smart way) to mitigate class imbalance. We wanted to focus on the techniques themselves and the effects they have to offer. We chose two classes out of the 10 classes. We call this our CIFAR-2 dataset (how creative of us).
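A minimal sketch of how such an imbalanced split can be carved out of CIFAR-10 is shown below, assuming CIFAR-10's label indices (0 = airplane, 1 = automobile); the variable names mirror the snippets later in this report but are otherwise illustrative.

import numpy as np
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train, y_test = y_train.flatten(), y_test.flatten()

# Keep all 5000 planes (label 0) but only the first 50 cars (label 1)
keep = np.concatenate([np.where(y_train == 0)[0],
                       np.where(y_train == 1)[0][:50]])
x_train_im = x_train[keep]
y_train_im = tf.keras.utils.to_categorical(y_train[keep], num_classes=2)

# The test set stays balanced: 1000 planes and 1000 cars
test_keep = np.where(y_test <= 1)[0]
x_test = x_test[test_keep]
y_test = tf.keras.utils.to_categorical(y_test[test_keep], num_classes=2)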
Below is the figure where one can see the distribution of the balanced CIFAR-2 dataset and its imbalanced counterpart.
For the experiments to be conclusive, we have chosen a very minimal model architecture with just 15,458 trainable parameters.
All the models are trained with early stopping.
A note for the readers: A binary classification problem that asks for the classification between the presence and absence of a single class requires a single output neuron. The output neuron is activated with the sigmoid function so that the output of the model can be interpreted as the probability of the presence of the class. On the other hand, in a classification problem with two classes, the model should have two output neurons. Both neurons output the probability of the presence of their respective classes.
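The exact 15,458-parameter architecture is not reproduced in this report; the sketch below is only an illustrative minimal CNN with the two-neuron softmax head and the early stopping callback mentioned above. The output layer is named 'last', which we will rely on later, and the patience value is an assumption.

import tensorflow as tf

def get_model():
    # Illustrative minimal CNN -- not necessarily the exact architecture we used
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = tf.keras.layers.Conv2D(8, 3, activation='relu')(inputs)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Conv2D(16, 3, activation='relu')(x)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(32, activation='relu')(x)
    # Two output neurons with softmax -- one probability per class
    outputs = tf.keras.layers.Dense(2, activation='softmax', name='last')(x)
    return tf.keras.Model(inputs, outputs)

model = get_model()
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Early stopping on validation loss, used for all the experiments
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True)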
The experimental observations are presented (as a mental map) in the following manner:
Here, the number of samples is perfectly balanced between the two classes. We wanted to train a minimal model with our dataset. First, we trained on the balanced dataset to check that the model, with its given capacity, does not overfit. The result of this is shown below.
We can see that our simple architecture was able to model our CIFAR-2 dataset properly. The confusion matrix shows that the model can correctly classify both classes.
Here we talk about our baseline. We train the model on the unbalanced data with nothing to counter the imbalance (no regularization). This is important as it will provide insights into the problems caused by data imbalance.
Note: We are considering "Airplane" as the negative class with one-hot encoding [1., 0.], while "Automobile" is the positive class with one-hot encoding [0., 1.].
One can clearly see that even though we reached a training accuracy of about 99%, the validation accuracy is 50%, showing that the model has overfitted on the training data. An interesting thing to note here is the validation accuracy: in the test set of the CIFAR-2 dataset, we have 1000 samples per class, so the model must be predicting every sample as the majority class, getting all of the majority class right and all of the minority class wrong. We can safely say that the model did not generalize at all.
The confusion matrix clearly shows that the model was not able to learn features from the under-represented class.
The ROC curve for this model is not hugging the upper left corner, signifying that the TPR and FPR increase at the same rate. We will see the AUC score in the comparative study section.
Note 1: The confusion matrix is generated with a threshold of 0.5, while the ROC curve is generated with multiple possible thresholds.
Note 2: In the ROC plot, the blue curve corresponds to the training data, while the red curve corresponds to the testing data.
One of the easiest ways to counter class imbalance is to use class weights, wherein we give different weightage to different classes. The number of samples in each class is considered while computing the class weights. We apply a larger weight to the minority class, which places more emphasis on that class. The classifier thus learns equally from both classes.
Class weights reweight the loss function. Misclassifying a minority-class sample incurs a higher loss since the minority class carries a higher weight. This forces the model to learn representations for the minority class, although it comes at the price of slightly reduced performance on the majority class.
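Conceptually, a class weight simply scales each sample's contribution to the loss. The tiny numeric sketch below uses the weights that the 'balanced' heuristic would give our 5000/50 split and two made-up predicted probabilities, to show how a poorly classified minority sample dominates the loss.

import numpy as np

# 'balanced' heuristic: n_samples / (n_classes * n_samples_in_class)
w_plane = 5050 / (2 * 5000)   # ~0.505 for the majority class
w_car = 5050 / (2 * 50)       # ~50.5 for the minority class

# Cross-entropy of a single sample is -log(p of the true class)
p_plane = 0.9   # hypothetical confidence on a plane sample
p_car = 0.4     # hypothetical (poor) confidence on a car sample

loss_plane = w_plane * -np.log(p_plane)   # ~0.05
loss_car = w_car * -np.log(p_car)         # ~46.3
print(loss_plane, loss_car)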
The easiest way to compute appropriate class weights is to use the sklearn utility function, as shown.
import numpy as np
from sklearn.utils import class_weight

# Labels are one-hot encoded, so recover the integer class indices first
y_labels = np.argmax(y_train_im, axis=1)
cls_wt = class_weight.compute_class_weight('balanced',
                                           classes=np.unique(y_labels),
                                           y=y_labels)
# For our 5000/50 split this is roughly {0: 0.505, 1: 50.5}
class_weights = {0: cls_wt[0], 1: cls_wt[1]}
Note: We are applying np.argmax as the labels are one-hot encoded.
If you are using Keras to build and train your model, applying class weights in the training loop is as easy as passing an argument to your model.fit method.
history = model.fit(trainloader,
                    class_weight=class_weights,  # Training with class weights
                    epochs=EPOCHS,
                    validation_data=testloader,
                    callbacks=[WandbCallback(),
                               early_stopping])
The effect of class weights is shown below.
The model trained stably even though the dataset is unbalanced. This is a huge improvement from the baseline model.
The model is not overfitting on the dataset. Instead, the model learned features to classify the minority class.
The confusion matrix further confirms that the model learned to classify the under-represented class as well. Even though this confusion matrix is not comparable to that of the model trained on the balanced dataset, the improvement over our baseline model is evident.
This is also evident from the ROC curve, which shifted towards the left corner.
Another way to deal with class imbalance is to use an oversampling strategy. Here, the minority class is resampled such that we have an equal representation of both classes. In this report, we use random oversampling, which is a naive way to oversample the minority class. However, it is easy to implement and will certainly give better results than an unregularized model.
With this technique, it is essential to note that we are artificially reducing the dataset's variance. The resulting model might do poorly upon data shift or data corruption.
We have used tf.data to build our sampling pipeline. We can use this pipeline to oversample as well as undersample. The pipeline is discussed as follows:
Say you have a regular tf.data trainloader, as shown below.
trainloader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
testloader = tf.data.Dataset.from_tensor_slices((x_test, y_test))

trainloader = (
    trainloader
    .shuffle(1024)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
The easiest way to get a dataloader that samples an equal number of training examples from both the majority and the minority class is to first build a positive and a negative dataset. We can use .unbatch and .filter to prepare them from the tf.data trainloader.
negative_ds = (
    trainloader
    .unbatch()
    .filter(lambda features, label: label[1]==classes_dict['plane'])
    .repeat())

positive_ds = (
    trainloader
    .unbatch()
    .filter(lambda features, label: label[1]==classes_dict['car'])
    .repeat())
The next step is to merge them using experimental.sample_from_datasets.
resampled_ds = tf.data.experimental.sample_from_datasets([negative_ds,
                                                          positive_ds],
                                                         weights=[0.5, 0.5])

resampled_ds = (
    resampled_ds
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Note the weights parameter: it ensures that samples are drawn from the positive and the negative dataset with equal probability.
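A quick sanity check (a sketch, assuming the one-hot labels described earlier) is to draw a single batch and look at the fraction of car samples, which should hover around 0.5:

import numpy as np

for images, labels in resampled_ds.take(1):
    # Fraction of car (positive) samples in the batch -- roughly 0.5
    print(np.argmax(labels.numpy(), axis=1).mean())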
We now have resampled_ds, but how can we use it to oversample the minority class? To do so, we specify the steps_per_epoch argument of the model.fit() method. Since the dataloader provides an equal number of samples from each class in every step (adding up to the batch size), the minority class is effectively oversampled if we choose steps_per_epoch such that every sample of the majority class is shown to the model at least once per epoch.
unique, counts = np.unique(np.argmax(y_train_im, axis=1), return_counts=True)
# counts[0] is the number of majority-class (plane) samples; show each of
# them at least once while half of every batch is drawn from each class
resampled_steps_per_epoch = np.ceil(2.0*counts[0]/BATCH_SIZE)
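Putting it together, the resampled loader and the step count are passed to model.fit; this sketch assumes the same callbacks as in the class weights experiment. The undersampling strategy below reuses the same pattern with its own steps_per_epoch.

history = model.fit(resampled_ds,
                    # steps_per_epoch is required because resampled_ds repeats forever
                    steps_per_epoch=int(resampled_steps_per_epoch),
                    epochs=EPOCHS,
                    validation_data=testloader,
                    callbacks=[WandbCallback(),
                               early_stopping])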
The result of training the model with this strategy is shown as loss and accuracy metrics.
Even though the model trained, it quickly overfits on the dataset.
Since the model is trained with 313 steps per epoch in our case, we believe this to be the reason for overfitting.
However, it would be interesting to look at other metrics to see if the model learned to classify minority class.
The confusion matrix clearly reveals that even though there is an improvement over the baseline model, the oversampling strategy did not perform better than the class weights method. The minority class (car) has more false negatives (227) compared to the class weights method (198).
Next up is our undersampling strategy, wherein we randomly remove samples from the majority class in a naive implementation. By doing so, we enable the model to learn key features from the minority class at the expense of missing out on key features of the majority class.
There is no golden rule to choose between oversampling and undersampling; it is usually recommended to try both and see which works best. However, in our opinion, if the majority class has repetitive features, removing random samples to balance with the minority class will not hurt much. Thus, one can try undersampling in those situations.
Using the sampling pipeline described above, we can undersample the majority class by ensuring that every sample of the minority class is shown at least once per epoch.
unique, counts = np.unique(np.argmax(y_train_im, axis=1), return_counts=True)
# counts[1] is the number of minority-class (car) samples this time
resampled_steps_per_epoch = np.ceil(2.0*counts[1]/BATCH_SIZE)
The result of this strategy is shown below.
The model did not overfit as expected.
There are fluctuations in the training and validation curves, which can be attributed to the constant shift in the pixel distribution caused by randomly removing samples from the majority class, or to using only a small fraction of the majority class per epoch.
The confusion matrix gives a better picture of the model, with few false positives. Thus, both class weights and undersampling resulted in a good model. Notice that the confusion matrix is comparable to that of the class weights method. We are not missing out on key features of the majority class even though we are dropping many of its samples.
We want to discuss the next strategy based on the Brain Tumor Segmentation with Deep Neural Networks paper.
In this paper, the authors work on image segmentation (brain tumor segmentation). It is commonly known that medical datasets have a huge class imbalance. To mitigate this, the authors suggest a two-path CNN architecture for segmentation and a two-phase training procedure to combat the imbalance in the training dataset.
We wanted to adopt this training procedure in our image classification problem to see if it can be applied to a classification task. We will not talk about the architecture here but will look into the training procedure.
The idea here is simple.
The first phase of training is to train the model with either an oversampled dataset or an undersampled dataset such that all the classes are equiprobable.
The second phase is to retrain only the output layer of the model with the original distribution of the dataset. The idea is that in the first phase, the model learns representations of both classes, and in the second phase, it also learns about the data imbalance.
An excerpt from the paper:
This way we get the best of both worlds: most of the capacity (the lower layers) is used in a balanced way to account for the diversity in all of the classes, while the output probabilities are calibrated correctly (thanks to the retraining of the output layer with the natural frequencies of classes in the data).
Note 1: This is an experimental study of applying a technique in a different deep learning task. The result might not be exciting.
Note 2: We will be discussing the effect of the 2nd phase of training. We have already trained our models with oversampling and undersampling (as discussed above) as our 1st phase of training.
For this experiment, we have trained our model with an oversampling strategy for phase one. We then freeze all the layers of the model except the output layer.
for layer in over_model.layers:
    # Selecting the layer by name; freeze everything except the output layer
    if layer.name != 'last':
        layer.trainable = False
Note: We have named the output layer 'last'.
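One detail worth noting: in Keras, changes to layer.trainable only take effect once the model is compiled again, so the second phase looks roughly like the sketch below (the optimizer and loss are assumptions mirroring phase one).

# Recompile so that the frozen layers take effect, then retrain only
# the output layer on the original (imbalanced) distribution
over_model.compile(optimizer='adam',
                   loss='categorical_crossentropy',
                   metrics=['accuracy'])

history_phase2 = over_model.fit(trainloader,   # original imbalanced loader
                                epochs=EPOCHS,
                                validation_data=testloader,
                                callbacks=[WandbCallback(),
                                           early_stopping])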
We then train the model in the 2nd phase with the unbalanced dataset. The results of both training phases are shown below. You can see that the model trained after the 2nd phase looks close to the baseline model.
Plane being the majority class and car being the minority class, we can clearly see the effect of the 2nd phase of training with oversampling as the first phase. With the oversampling technique we achieved a better classifier than the baseline; however, after the 2nd phase of training, the number of true positives decreased by a huge margin.
We have tried out some simple remedies to overcome class imbalance for the image classification task.
We have used AUC, Recall, and Precision as metrics to compare different experiments.
The model trained with the balanced dataset has the best metric scores, which is expected, while the model trained with the unbalanced dataset has the worst. We have seen the confusion matrices for both.
Class weights, oversampling, and undersampling have all improved the classifier over the baseline.
Both class weights and undersampling resulted in equally good models, while oversampling comes close.
Our use of the "two-phase training procedure," which we adopted from the Brain Tumor Segmentation with Deep Neural Networks paper, **did not** work out in our image classification setting. Even with only 66 trainable parameters in the output layer, the model quickly overfit on the unbalanced dataset. To our disappointment, this significantly reduced the classifier's performance compared to phase one of training.
We highly recommend going through these resources (in no particular order):
We hope you take away something from this report. When working in the realms of data science, one often stumbles upon imbalanced datasets. It would be great if we could impart some information to tackle the problem. We are open to suggestions and feedback via Twitter: @ayushthakur0 and @ariG23498.
Aritra Roy Gosthipaty and Ayush Thakur contributed equally to this report.