
How to Train a Model for Chest X-Ray COVID Diagnosis

In this article, I walk through my mini-project for the Weights & Biases MLOps course, in which I train a model for chest X-ray COVID-19 diagnosis.
Created on January 30 | Last edited on February 25
Note: This is a community submission from Max, who created this project while taking our Effective MLOps course. It's available free on demand if you'd like to check it out!
Recent research suggests that chest X-rays carry significant diagnostic information about COVID-19. There are only a limited number of radiologists and expert clinicians available to read these X-rays. X-rays are also fairly uniform in nature (compared with other image types such as personal photographs). For both reasons, fast and accurate models could dramatically improve patient outcomes and reduce the load on overburdened technicians.



About Our Dataset

The original dataset comes from Kaggle and is organized into two folders (train, test). Both folders contain three subfolders (COVID19, NORMAL, PNEUMONIA), and our model will classify X-rays into these three categories. Additional info:
  • The training dataset contains 6432 X-ray images in total, including 606 COVID19 images, 1266 NORMAL images, and 4273 PNEUMONIA images.
  • The dataset is imbalanced, mainly toward the PNEUMONIA class and, to a lesser extent, the NORMAL class.
  • To speed up training, there is also a smaller partition of the dataset with just 1500 images (500 per class).


In the chest X-rays of COVID-19 patients, both lungs may show consolidation in the peripheral and lower regions, which can appear as a "white lung."

Split Dataset (Train/Val/Test)

A purely random partition can leave the test set with more or fewer samples of a given class than intended. That makes comparisons on the test set unfair. Instead, we perform a class-balanced (stratified) 80%/10%/10% split into train, val, and test sets respectively. You can get a feel for the dataset in this W&B Table:
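A class-balanced split can be sketched in plain Python. This is a minimal sketch, not the project's code. It assumes the images are available as a list of hypothetical `(path, label)` pairs, and it shuffles and slices each class separately so every split keeps the same class proportions. scikit-learn's `train_test_split` with its `stratify` argument can achieve the same result in two calls.

```python
import random
from collections import defaultdict

def stratified_split(samples, train_frac=0.8, val_frac=0.1, seed=42):
    """Split (path, label) pairs into train/val/test, keeping class proportions.

    Each class is shuffled and sliced on its own; whatever remains after the
    train and val slices goes to test.
    """
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append((path, label))

    rng = random.Random(seed)
    train, val, test = [], [], []
    for label in sorted(by_class):
        items = by_class[label]
        rng.shuffle(items)
        n_train = int(len(items) * train_frac)
        n_val = int(len(items) * val_frac)
        train += items[:n_train]
        val += items[n_train:n_train + n_val]
        test += items[n_train + n_val:]
    return train, val, test

# Hypothetical file names: 500 images per class, like the smaller partition.
classes = ["COVID19", "NORMAL", "PNEUMONIA"]
samples = [(f"{c}/img_{i}.png", c) for c in classes for i in range(500)]
train, val, test = stratified_split(samples)
print(len(train), len(val), len(test))  # 1200 150 150
```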



Data Analysis

Upon examining the dataset, I noticed variations in image size and zoom ratio (for example, the lungs look large in some images and small in others). Some images also contain text and annotations, while others are overexposed, differently colored, or blurry. Here are some examples of less-than-ideal images:
Some challenging images that can confuse our model; some are difficult even for humans to interpret.
It may be worthwhile to remove these challenging images, or to use additional data augmentation to improve the model's robustness and its ability to generalize to real-world scenarios. For now, though, we'll press ahead.
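If we did decide to filter images, two simple heuristics could help. Blurry images can be flagged by the variance of a Laplacian filter, since low variance means few sharp edges. Badly exposed images can be flagged by their mean brightness. The sketch below assumes 8-bit grayscale images stored as NumPy arrays. The function names and thresholds are hypothetical and would need tuning on the real X-rays.

```python
import numpy as np

def laplacian_variance(gray):
    """Variance of a 4-neighbor Laplacian; low values mean few sharp edges (blur)."""
    g = gray.astype(np.float64)
    lap = (g[:-2, 1:-1] + g[2:, 1:-1] + g[1:-1, :-2] + g[1:-1, 2:]
           - 4.0 * g[1:-1, 1:-1])
    return float(lap.var())

def flag_image(gray, blur_thresh=50.0, dark_thresh=40.0, bright_thresh=215.0):
    """Return quality flags for an 8-bit grayscale image.

    The thresholds are hypothetical starting points, not tuned values.
    """
    flags = []
    if laplacian_variance(gray) < blur_thresh:
        flags.append("blurry")
    mean = float(gray.mean())
    if mean < dark_thresh:
        flags.append("underexposed")
    elif mean > bright_thresh:
        flags.append("overexposed")
    return flags

rng = np.random.default_rng(0)
sharp = rng.integers(0, 256, size=(64, 64)).astype(np.uint8)  # lots of fine detail
flat_bright = np.full((64, 64), 240, dtype=np.uint8)          # featureless and very bright
print(flag_image(sharp))        # []
print(flag_image(flat_bright))  # ['blurry', 'overexposed']
```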

Data Pre-Processing & Augmentation

To make the model more robust and avoid overfitting, we apply data augmentation with Albumentations, which supports a wide variety of image augmentations. The images are resized to a height of 320 and a width of 300 pixels, the input size used for our model.
For a simple baseline model, the only augmentation is a horizontal flip applied with 50% probability. However, it might be worth exploring more image augmentations to deal with the varying quality of the X-ray images in the dataset.
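```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

img_size = (320, 300)  # (height, width), as in the config below

# Baseline training transform (assigned to self.transform inside the
# project's Dataset class): resize, random horizontal flip, ImageNet
# normalization, then conversion to a PyTorch tensor.
transform = A.Compose([
    A.Resize(height=img_size[0], width=img_size[1]),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])
```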
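The varying zoom ratio noted during data analysis is one natural candidate for an extra augmentation. Albumentations provides transforms for this, for example `A.Affine` with a `scale` range. The dependency-free NumPy sketch below only illustrates the idea. `random_zoom` is a hypothetical helper, not part of the project. It picks a random zoom factor and resamples the image around its center with nearest-neighbor lookup: it crops when zooming in and zero-pads when zooming out.

```python
import numpy as np

def random_zoom(img, zoom_range=(0.9, 1.1), rng=None):
    """Randomly zoom an H x W (or H x W x C) image in or out, keeping its shape.

    zoom > 1 enlarges the central region (zoom in); zoom < 1 shrinks the image
    and fills the uncovered border with zeros (zoom out). Nearest-neighbor
    sampling keeps the sketch dependency-free.
    """
    if rng is None:
        rng = np.random.default_rng()
    zoom = rng.uniform(*zoom_range)
    h, w = img.shape[:2]
    # Map each output pixel back to its source pixel around the image center.
    ys = np.rint((np.arange(h) - (h - 1) / 2) / zoom + (h - 1) / 2).astype(int)
    xs = np.rint((np.arange(w) - (w - 1) / 2) / zoom + (w - 1) / 2).astype(int)
    valid_y = (ys >= 0) & (ys < h)
    valid_x = (xs >= 0) & (xs < w)
    out = np.zeros_like(img)
    out[np.ix_(valid_y, valid_x)] = img[np.ix_(ys[valid_y], xs[valid_x])]
    return out

rng = np.random.default_rng(42)
xray = rng.random((320, 300))  # stand-in for a resized grayscale X-ray
zoomed = random_zoom(xray, rng=rng)
print(zoomed.shape)  # (320, 300)
```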

Training our Baseline Model

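Thanks to PyTorch Lightning's integration with Weights & Biases, we only need to pass a WandbLogger to the Trainer:
```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Log training progress to W&B; log_model=True also uploads model checkpoints
wandb_logger = WandbLogger(project=CONFIG.project, log_model=True)

# Log additional config parameters
wandb_logger.experiment.config.update(CONFIG)

# Initialize a trainer
trainer = Trainer(
    logger=wandb_logger,
    # ... (other Trainer arguments not shown in the original)
)
```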
The configuration is also logged, so the settings of each experiment are easy to track:
```python
import torch

# Config is the project's own settings class (its definition is not shown in
# this report); CONFIG is the object used as CONFIG.project in the logger setup.
CONFIG = Config(
    project='mlops-course-X-ray',  # project name for logging to W&B
    seed=42,
    epochs=20,
    data_file_path='./artifacts/data-xray-tiny-split-v0',
    evaluate='test',  # evaluation split: 'val' or 'test'
    num_workers=4,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    model_name='tf_efficientnetv2_b3',
    pre_trained=False,
    fine_tune=False,
    num_classes=3,
    img_size=(320, 300),  # (height, width)
    batch_size=32,
    num_batches=37,  # number of training batches per epoch
    optimizer='AdamW',  # AdamW, SGD, RAdam, Ranger21
    learning_rate=1e-3,
    lr_scheduler=None,  # None, CosineAnnealingLR, CosineAnnealingWarmRestarts, CustomCosineAnnealingWarmUpRestarts
    momentum=0.9,
    weight_decay=5e-6,
    label_smoothing=0.0,
)
```
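The `label_smoothing` setting softens the one-hot training targets. It is 0.0 for the baseline, and the sweep later in this report uses 0.1. Below is a minimal NumPy sketch of the common formulation, which PyTorch's `CrossEntropyLoss(label_smoothing=...)` also follows. Each target is mixed with a uniform distribution over the K classes, so the true class receives 1 - ε + ε/K and every other class receives ε/K. The helper names are hypothetical.

```python
import numpy as np

def smoothed_targets(labels, num_classes, smoothing=0.1):
    """Mix one-hot targets with a uniform distribution over the classes."""
    onehot = np.eye(num_classes)[labels]
    return (1.0 - smoothing) * onehot + smoothing / num_classes

def cross_entropy(logits, targets):
    """Mean cross-entropy between (soft) targets and softmax(logits)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(targets * log_probs).sum(axis=1).mean())

labels = np.array([0, 2])             # e.g. COVID19 and PNEUMONIA
logits = np.array([[10.0, 0.0, 0.0],  # confident, correct predictions
                   [0.0, 0.0, 10.0]])

plain = cross_entropy(logits, smoothed_targets(labels, 3, smoothing=0.0))
smooth = cross_entropy(logits, smoothed_targets(labels, 3, smoothing=0.1))
print(f"plain CE: {plain:.4f}, smoothed CE: {smooth:.4f}")
```

With confident, correct logits the plain loss is almost zero, while the smoothed loss stays clearly above zero. This is how label smoothing discourages the model from becoming overconfident.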

Training Analysis

From the training results, we can see that the baseline model converges quickly, within about 5 epochs, on our evaluation metrics (accuracy, precision, recall, and F1 score).
To evaluate the overall performance of the classifier, we use macro-averaged precision, recall, and F1 score. Macro averaging weights every class equally, which keeps the evaluation fair on imbalanced datasets.
The best models in terms of validation F1 score are logged. The validation metrics show many spikes, including a sharp peak at an F1 score of 0.96 (96%).
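Macro averaging can be sketched as follows, assuming integer class labels (for example 0 = COVID19, 1 = NORMAL, 2 = PNEUMONIA, a hypothetical encoding). Precision, recall, and F1 are computed per class from a confusion matrix and then averaged with equal weight, so a minority class counts as much as a majority one. This is an illustration, not the project's metric code. In practice, scikit-learn's `precision_recall_fscore_support(..., average="macro")` computes the same quantities.

```python
import numpy as np

def macro_scores(y_true, y_pred, num_classes):
    """Macro-averaged precision, recall, and F1 (per-class scores, equally weighted)."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)  # rows: true, cols: predicted
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # how often each class was predicted
    actual = cm.sum(axis=1)     # how often each class truly occurs
    zeros = np.zeros_like(tp)
    precision = np.divide(tp, predicted, out=zeros.copy(), where=predicted > 0)
    recall = np.divide(tp, actual, out=zeros.copy(), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=zeros.copy(), where=denom > 0)
    # Macro average: every class counts equally, regardless of its size.
    return float(precision.mean()), float(recall.mean()), float(f1.mean())

# Hypothetical labels: one class-0 image is misclassified, hurting class 0's recall.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2]
p, r, f1 = macro_scores(y_true, y_pred, num_classes=3)
print(f"macro P={p:.3f} R={r:.3f} F1={f1:.3f}")
```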



Baseline Evaluation

The best model, with an F1 score of 0.96 on the validation set, is evaluated on the test set. A confusion matrix gives insight into how well the model predicts on new data.
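The confusion matrix itself is simple to compute. The sketch below assumes integer labels and predictions with the class order COVID19, NORMAL, PNEUMONIA. The predictions in the usage example are hypothetical. To log a matrix to W&B, `wandb.plot.confusion_matrix` accepts `y_true`, `preds`, and `class_names` directly.

```python
import numpy as np

CLASS_NAMES = ["COVID19", "NORMAL", "PNEUMONIA"]

def confusion_matrix(y_true, y_pred, num_classes):
    """Count matrix: rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def format_matrix(cm, names):
    """Render the matrix as aligned text with class names on both axes."""
    width = max(len(n) for n in names) + 2
    lines = [" " * width + "".join(n.rjust(width) for n in names)]
    for name, row in zip(names, cm):
        lines.append(name.ljust(width) + "".join(str(v).rjust(width) for v in row))
    return "\n".join(lines)

# Hypothetical labels and predictions, for illustration only.
y_true = [0, 0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 2, 2, 2]
cm = confusion_matrix(y_true, y_pred, len(CLASS_NAMES))
print(format_matrix(cm, CLASS_NAMES))
# Dividing the diagonal by the row sums gives per-class recall.
print(np.round(cm.diagonal() / cm.sum(axis=1), 2))
```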



Sweep Configuration

We perform a hyperparameter search with W&B Sweeps to optimize our model. In this project, we run a small sweep over two models and two optimizers, as specified in the sweep configuration file. For all runs in the sweep, I used label smoothing of 0.1, a cosine annealing learning rate scheduler, and an extra zoom-ratio augmentation.
```yaml
method: grid
metric:
  goal: maximize
  name: val/accuracy
name: sweep
parameters:
  model_name:
    values:
      - tf_efficientnetv2_b3
      - efficientnetv2_s
  optimizer:
    values:
      - RAdam
      - SGD
```
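With `method: grid`, W&B Sweeps launches one run for every combination of the listed parameter values. The plain-Python sketch below is not W&B's implementation. It simply enumerates that grid from the values in the sweep configuration, which gives four runs in total. The analysis below notes that one of them, efficientnetv2_s with SGD, did not run because of an out-of-memory error. With the wandb package, the sweep itself is created with `wandb.sweep(...)` and executed by `wandb.agent(...)`.

```python
from itertools import product

# Parameter values taken from the sweep configuration.
grid = {
    "model_name": ["tf_efficientnetv2_b3", "efficientnetv2_s"],
    "optimizer": ["RAdam", "SGD"],
}

# A grid search runs every combination of parameter values once.
runs = [dict(zip(grid, combo)) for combo in product(*grid.values())]
for run in runs:
    print(run)
print(len(runs))  # 4
```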

Sweep Analysis

We created a parallel coordinates plot to visualize the outcome of the sweep, using two metrics as evaluation criteria:
  1. Test accuracy to check how well it generalizes to new data.
  2. Individual test accuracy of COVID-19, since we want to distinguish COVID-19 from the other classes.
We can see that the efficientnetv2_s model improves test accuracy by up to 5%, while the SGD optimizer gives a smaller gain of about 1%.
Note that efficientnetv2_s was not run with SGD because of an out-of-memory error. We might expect efficientnetv2_s with SGD to generalize even better on new data, but this combination remains untested.



Model Evaluation

The results of the sweep show that the challenger model (efficientnetv2_s + RAdam) 🏆 slightly outperforms the baseline in test accuracy (0.9533 vs. 0.9467), particularly in detecting pneumonia, as seen in the confusion matrix below. This is interesting because the baseline model had a higher maximum validation accuracy (0.96 vs. 0.9267). It suggests that the challenger may generalize better to new data.
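To put the accuracy gap in perspective, assume the test set is the 10% split of the 1500-image partition. That gives 150 images, consistent with the 50 COVID-19 test cases mentioned in the error analysis. Under that assumption, the two accuracies convert back to image counts:

```python
test_size = 1500 // 10  # 10% of the 1500-image partition -> 150 test images

# Reported test accuracies from the sweep comparison.
challenger_acc, baseline_acc = 0.9533, 0.9467

# Convert each accuracy back into a count of correctly classified images.
challenger_correct = round(challenger_acc * test_size)
baseline_correct = round(baseline_acc * test_size)
print(challenger_correct, baseline_correct)  # 143 142
print(challenger_correct - baseline_correct)  # 1
```

Under this assumption, the gap between the two models comes down to a single test image.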
The individual COVID-19 test accuracy is the same for both models. However, the challenger has more false negatives (2 vs. 1) but fewer false positives (1 vs. 2). When detecting COVID-19 from X-ray images, we prioritize minimizing false negatives, because a missed COVID-19 diagnosis could have life-threatening consequences.
These conflicting results make it difficult to determine which model is superior, as both have their own strengths and weaknesses ⚖️.
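Why can the COVID-19 accuracy be identical when the false negatives and false positives differ? Assume "individual test accuracy" means one-vs-rest accuracy (COVID-19 vs. everything else). Then it depends only on the total number of COVID-19-related errors (FN + FP), which is three for both models. The sketch below uses two hypothetical confusion matrices whose COVID-19 row and column reproduce the reported counts (1 FN and 2 FP vs. 2 FN and 1 FP). All other cells are made up for illustration.

```python
import numpy as np

def one_vs_rest(cm, k):
    """False negatives, false positives, and binary accuracy for class k."""
    cm = np.asarray(cm)
    tp = cm[k, k]
    fn = cm[k].sum() - tp     # class-k images predicted as another class
    fp = cm[:, k].sum() - tp  # other images predicted as class k
    total = cm.sum()
    acc = (total - fn - fp) / total
    return int(fn), int(fp), float(acc)

# Hypothetical matrices (rows = true, cols = predicted; order COVID19,
# NORMAL, PNEUMONIA). These are NOT the report's actual results.
baseline_like = [[49, 1, 0], [2, 45, 3], [0, 2, 48]]
challenger_like = [[48, 2, 0], [1, 46, 3], [0, 2, 48]]
print(one_vs_rest(baseline_like, 0))    # (1, 2, 0.98)
print(one_vs_rest(challenger_like, 0))  # (2, 1, 0.98)
```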



Error Analysis

Error analysis is a crucial step in understanding the strengths and weaknesses of a model and identifying areas for improvement. To do this, we can build a prediction table that includes the input image, a visual explanation using Grad-CAM, the label, the prediction, and the scores. This table lets us check the images where the model made mistakes, how confident it was, and where the model looked to make its prediction.
Explainable AI (XAI) is key to understanding how AI models make predictions, particularly in high-stakes applications such as medical diagnosis. In image classification with Convolutional Neural Networks (CNNs), visual explanations show where the CNN is "looking" in the image and which features it considers relevant for the prediction. This understanding supports trust in and transparency of the model, and it helps with debugging and improving its performance.
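The report does not show how its Grad-CAM maps were produced. As an illustration only, the core Grad-CAM computation fits in a few lines of NumPy. The sketch assumes we have already extracted the feature maps of a chosen convolutional layer and the gradients of the target class score with respect to them, for example with PyTorch hooks or a library such as pytorch-grad-cam. Each feature map is weighted by its spatially averaged gradient, and the weighted maps are summed. A ReLU then keeps only the regions that increase the class score. Upsampling the map to the image size and overlaying it as a heatmap are left out.

```python
import numpy as np

def grad_cam(activations, gradients):
    """Grad-CAM heatmap for one convolutional layer.

    activations, gradients: arrays of shape (K, H, W) holding the layer's K
    feature maps and the gradient of the target class score with respect to them.
    Returns an (H, W) map scaled to [0, 1].
    """
    weights = gradients.mean(axis=(1, 2))             # one weight per feature map
    cam = np.tensordot(weights, activations, axes=1)  # weighted sum over K -> (H, W)
    cam = np.maximum(cam, 0.0)                        # ReLU: keep positive evidence only
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Toy example: map 0 fires in the top-left and supports the class (positive
# gradient); map 1 fires in the bottom-right and argues against it.
acts = np.zeros((2, 4, 4))
acts[0, :2, :2] = 1.0
acts[1, 2:, 2:] = 1.0
grads = np.stack([np.full((4, 4), 0.5), np.full((4, 4), -0.5)])
cam = grad_cam(acts, grads)
print(cam)
```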

Challenger vs Baseline

The baseline model achieved solid results in distinguishing COVID-19 from the other classes: only 1 false negative among the 50 COVID-19 test cases, plus 2 false positives. However, its visual explanations sometimes highlight areas where text appears in the images (red marks important regions, blue marks unimportant ones). This suggests that the model may not have learned to properly identify and locate the parts of the lungs affected by COVID-19.
In a real-life scenario, the model would receive raw images without such pre-written text. Relying on that text is a form of data leakage: it gives the model knowledge about the future, meaning information that would not be available at prediction time.


The challenger from the sweep performs similarly to the baseline model, which makes it hard to determine which is superior. However, its visual explanations show a significant improvement in localizing regions inside the lungs. Because I am not a radiologist, it's difficult for me to judge whether the regions the model focuses on correspond to the areas that matter for diagnosing COVID-19 from these X-ray images.


Comparing the visual explanations of both models for a specific input image suggests that the challenger is more reliable, since it does not appear to rely on the text present in the image. Explainable AI lets us differentiate between a strong and a weak classifier even when they have similar scores on evaluation metrics. This helps establish trust in the model and its ability to work with real-world data.
Both models make the correct prediction, but they reach that result in different ways.

What's Next?

For future work, there are many ways we could further optimize the model:
  • Most importantly, consult a radiologist to learn how well the model localizes the decisive areas.
  • Train the model on the original dataset.
  • Remove the text that contains future information from the images.
  • Obtain additional images similar to the ones where the model's prediction was incorrect.
  • Remove the most challenging images (e.g., too much blur, overexposure, or underexposure) and use more image augmentations (e.g., zoom ratio, rotation of around 10-15 degrees).
  • Perform a larger hyperparameter sweep to experiment with learning rate, batch size, weight decay, learning-rate schedulers and label smoothing.
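As a hypothetical starting point for that larger sweep, the configuration below searches over the remaining Config fields. It is written as the Python dict that `wandb.sweep` also accepts. The parameter names come from the project's Config; the `bayes` method and all value ranges are assumptions. Bayesian search samples promising settings instead of enumerating every combination, which matters once a continuous range such as the learning rate is involved. Batch sizes stay at 32 or below because of the out-of-memory error seen earlier.

```python
# Hypothetical follow-up sweep: parameter names match keys in the project's
# Config, but the method and value ranges are assumptions, not tested settings.
sweep_config = {
    "method": "bayes",  # sample promising settings instead of enumerating a grid
    "metric": {"name": "val/accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-2},
        "batch_size": {"values": [16, 32]},
        "weight_decay": {"values": [0.0, 5e-6, 1e-4]},
        "lr_scheduler": {"values": ["CosineAnnealingLR", "CosineAnnealingWarmRestarts"]},
        "label_smoothing": {"values": [0.0, 0.1]},
    },
}

# With the wandb package installed, the sweep could then be started with:
#   sweep_id = wandb.sweep(sweep_config, project="mlops-course-X-ray")
#   wandb.agent(sweep_id, function=train)  # train(): the project's training entry point (not shown)
```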