Predicting Lung Disease with Binary Classification on the NIH Chest X-ray Dataset
In this report, we will perform binary classification on the NIH Chest X-ray dataset. Made by Ayush Thakur using Weights & Biases
Machine learning on medical imagery has been a promising domain for quite a while now. In fact, many in the field think ML-centric diagnoses are a matter of “when,” not “if.” But because false negatives and false positives can be so harmful to patients, both industry and researchers in this field remain fairly cautious.
Chest X-rays, like most medical images, are close to ideal from a data perspective. They’re fairly uniform in size and angle, and many are publicly available (with personally identifying information redacted, of course).
Today, we’re going to look at whether we can leverage an NIH dataset of these images to predict lung disease diagnoses. Specifically, our output is a prediction of whether we’re looking at a normal or an abnormal lung.
Task Performed: Binary Classification
Input Type: Image
Output: Prediction score denoting either normal or abnormal lung.
Let’s dig in:
The NIH Chest X-ray dataset comprises 112,120 X-ray images from 30,805 unique patients, annotated with 14 text-mined disease labels. The 14 disease labels are Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, Fibrosis, Hernia, Infiltration, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax.
To create these labels, the authors used Natural Language Processing to text-mine disease classifications from the associated radiological reports. The labels are expected to be >90% accurate and suitable for weakly-supervised learning.
License and Attribution
The model is intended to classify a given X-ray image as either normal (no associated disease) or abnormal (one or more diseases). It therefore performs binary classification.
Research: To further research on automatic, deep learning-based “reading chest X-rays” for computer-aided diagnosis (CAD).
Pretrained Weights: To provide pre-trained weights for downstream tasks involving X-ray images.
Promote: To promote the use of model cards for reporting models.
Uses to avoid
51,759 samples of the NIH Chest X-ray dataset are labeled with one or more diseases (multi-label). The label for these samples is converted to 1.
The remaining samples are labeled No Finding: the NLP-based labeling technique used by the dataset's authors could not associate any disease with them. The label for these samples is converted to 0.
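The binarization rule above can be sketched with the standard library's csv module. This is a minimal sketch, assuming the metadata CSV has `Image Index` and `Finding Labels` columns, with multiple findings joined by `|` (check these against your copy of the file); the rows in `SAMPLE_CSV` are hypothetical placeholders, not real records.

```python
import csv
import io

# Hypothetical excerpt of the metadata CSV (placeholder rows, not real records).
SAMPLE_CSV = """Image Index,Finding Labels,Patient Age,Patient Gender
img_a.png,Cardiomegaly,50,M
img_b.png,Effusion|Mass,60,F
img_c.png,No Finding,30,F
"""

def binary_label(finding_labels: str) -> int:
    """Map a 'Finding Labels' string to 0 (normal) or 1 (one or more diseases)."""
    return 0 if finding_labels.strip() == "No Finding" else 1

def load_binary_labels(csv_file) -> dict:
    """Return {image filename: 0/1 label} for every row of the metadata CSV."""
    reader = csv.DictReader(csv_file)
    return {row["Image Index"]: binary_label(row["Finding Labels"]) for row in reader}

labels = load_binary_labels(io.StringIO(SAMPLE_CSV))
print(labels)
```

For the real file, pass an open file handle (`open(path, newline="")`) instead of the `StringIO` stand-in.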
20,000 training images, 5,000 validation images, and 10,000 test images were used to train, validate, and test model:v0.
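The report gives the subset sizes but not how the images were drawn. A minimal sketch, assuming a seeded random split at the image level (the seed and the function name `split_indices` are hypothetical choices):

```python
import random

def split_indices(n_total, n_train=20_000, n_val=5_000, n_test=10_000, seed=42):
    """Randomly draw disjoint train/val/test index lists from range(n_total)."""
    needed = n_train + n_val + n_test
    if needed > n_total:
        raise ValueError(f"need {needed} samples but only {n_total} are available")
    rng = random.Random(seed)
    # Sample without replacement, then slice the result into the three subsets.
    chosen = rng.sample(range(n_total), needed)
    train = chosen[:n_train]
    val = chosen[n_train:n_train + n_val]
    test = chosen[n_train + n_val:]
    return train, val, test

train_idx, val_idx, test_idx = split_indices(112_120)
print(len(train_idx), len(val_idx), len(test_idx))  # 20000 5000 10000
```

Because the dataset has several images per patient on average (112,120 images from 30,805 patients), an image-level split can put the same patient in both the training and test sets; grouping images by patient before splitting avoids that leakage.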
Preprocessing: The original images are 1024 x 1024 pixels. They are resized to 256 x 256 pixels, and the pixel values of the resized images are then scaled down.
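The preprocessing step can be sketched in NumPy. This assumes a single-channel 8-bit image, uses block averaging for the 4x downsample (the report does not name the resize method), and interprets "scaled down" as dividing pixel values by 255 so they fall in [0, 1]; all three are assumptions.

```python
import numpy as np

def preprocess(image: np.ndarray, out_size: int = 256) -> np.ndarray:
    """Downsample a square 8-bit image by block averaging and scale pixels to [0, 1]."""
    h, w = image.shape
    if h != w or h % out_size != 0:
        raise ValueError("expected a square image whose side is a multiple of out_size")
    factor = h // out_size  # 1024 // 256 == 4
    # Average each factor x factor block of pixels (area downsampling).
    small = image.reshape(out_size, factor, out_size, factor).mean(axis=(1, 3))
    # Scale 8-bit intensities into [0, 1] (assumed meaning of "scaled down").
    return (small / 255.0).astype(np.float32)

x = np.random.default_rng(0).integers(0, 256, size=(1024, 1024), dtype=np.uint8)
y = preprocess(x)
print(y.shape, float(y.min()), float(y.max()))
```

Framework resize ops (for example bilinear interpolation) give slightly different pixel values than block averaging, so match whatever the training pipeline used when preparing images for a trained model.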
model:v1 is trained from scratch with ResNet-50 as the backbone architecture.
The output of the global max pooling layer is passed through a ReLU-activated Dense layer with 512 units, followed by a dropout layer (drop rate of 0.2). The output layer is sigmoid-activated.
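The architecture described above fits in a few lines of Keras. This is a sketch, not the author's code: it assumes TensorFlow/Keras (suggested by the `.h5` artifact, but not stated), a 256 x 256 input with three channels (grayscale replicated), and `weights=None` to match "trained from scratch"; `build_model` is a hypothetical name.

```python
import tensorflow as tf

def build_model(input_shape=(256, 256, 3)) -> tf.keras.Model:
    """ResNet-50 backbone + global max pooling + Dense(512, relu) + Dropout(0.2) + sigmoid."""
    backbone = tf.keras.applications.ResNet50(
        include_top=False,  # drop the ImageNet classifier head
        weights=None,       # random initialization, i.e. trained from scratch
        input_shape=input_shape,
    )
    x = tf.keras.layers.GlobalMaxPooling2D()(backbone.output)
    x = tf.keras.layers.Dense(512, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    # A single sigmoid unit gives the probability that the lung is abnormal.
    output = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    return tf.keras.Model(inputs=backbone.input, outputs=output)

model = build_model()
print(model.output.shape)
```

A single sigmoid unit (rather than a two-way softmax) is the usual choice for a binary target, and it pairs with binary cross-entropy during training.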
Training-related specifics
Adam optimizer with a learning rate of 0.001 is used.
Cross-entropy loss is used.
The model is trained with early stopping.
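The training settings listed above map onto Keras' `compile`/`fit` API. A minimal sketch, assuming TensorFlow/Keras and binary cross-entropy as the concrete cross-entropy loss for a single sigmoid output; the early-stopping monitor (`val_loss`) and `patience` value are hypothetical, since the report gives neither. A tiny stand-in model and random data replace the real pipeline so the snippet runs on its own.

```python
import numpy as np
import tensorflow as tf

def compile_and_get_callbacks(model: tf.keras.Model, patience: int = 5):
    """Compile with Adam(lr=1e-3) and binary cross-entropy; return an EarlyStopping callback."""
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss="binary_crossentropy",  # cross-entropy for a single sigmoid output
        metrics=[tf.keras.metrics.AUC(name="auc"), "accuracy"],
    )
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",         # assumed: stop when validation loss stops improving
        patience=patience,          # hypothetical value, not given in the report
        restore_best_weights=True,  # keep the weights from the best epoch
    )
    return [early_stopping]

# Tiny stand-in model and random data, only to exercise the training loop.
toy = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
callbacks = compile_and_get_callbacks(toy, patience=2)
rng = np.random.default_rng(0)
x = rng.random((64, 8)).astype("float32")
y = (rng.random((64, 1)) > 0.5).astype("float32")
history = toy.fit(x, y, validation_split=0.25, epochs=20, callbacks=callbacks, verbose=0)
print(len(history.history["loss"]), "epochs run")
```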
Evaluation is done on the held-out test set. The ROC curve and the test error rate are used as evaluation metrics.
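Both metrics can be computed from the model's sigmoid scores. A pure-Python sketch for illustration: the 0.5 decision threshold for the error rate is an assumption, and in practice `sklearn.metrics.roc_curve` and `sklearn.metrics.roc_auc_score` do the same job.

```python
from typing import List, Sequence, Tuple

def error_rate(labels: Sequence[int], scores: Sequence[float], threshold: float = 0.5) -> float:
    """Fraction of samples whose thresholded prediction disagrees with the 0/1 label."""
    wrong = sum(int(s >= threshold) != y for y, s in zip(labels, scores))
    return wrong / len(labels)

def roc_curve(labels: Sequence[int], scores: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Return (fpr, tpr), lowering the threshold through the distinct scores from high to low."""
    pos = sum(labels)
    neg = len(labels) - pos
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    pairs = sorted(zip(scores, labels), reverse=True)
    i = 0
    while i < len(pairs):
        j = i
        # Samples with tied scores cross the threshold together.
        while j < len(pairs) and pairs[j][0] == pairs[i][0]:
            tp += pairs[j][1]
            fp += 1 - pairs[j][1]
            j += 1
        fpr.append(fp / neg)
        tpr.append(tp / pos)
        i = j
    return fpr, tpr

def auc(fpr: List[float], tpr: List[float]) -> float:
    """Area under the ROC curve by the trapezoidal rule."""
    return sum((fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2 for k in range(1, len(fpr)))

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
fpr, tpr = roc_curve(y_true, y_score)
print(error_rate(y_true, y_score))  # 0.25
print(auc(fpr, tpr))                # 0.75
```

The error rate depends on the chosen threshold, whereas the ROC curve summarizes performance across all thresholds, which is why the two metrics complement each other.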
The Data_Entry_2017_v2020.csv file that ships with the NIH Chest X-ray dataset contains class labels as well as patient data. The patient data provided are:
Gender: Male or Female
Age: Continuous value
Neither age nor gender was provided to the model as an input during training.
Bias Towards Gender
The model is evaluated separately on the male-only (blue) and female-only (orange) subsets of the test data.
The model gives better predictions for X-rays of male patients.
This reflects a gender imbalance in the training dataset.
The bias originates in the dataset.
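A per-subgroup evaluation like the one above can be sketched as follows. It assumes gender codes such as `M`/`F` taken from the metadata CSV and hard 0/1 predictions; the input lists are hypothetical toy values, not results from the report. Returning the sample count next to each error rate makes it easier to check whether a gap lines up with subgroup size.

```python
from collections import defaultdict

def error_rate_by_group(groups, labels, predictions):
    """Return {group: (error rate, sample count)} for hard 0/1 predictions."""
    wrong = defaultdict(int)
    total = defaultdict(int)
    for g, y, p in zip(groups, labels, predictions):
        total[g] += 1
        wrong[g] += int(p != y)
    return {g: (wrong[g] / total[g], total[g]) for g in total}

# Hypothetical toy values, not results from the report.
genders = ["M", "M", "F", "F", "F"]
y_true = [1, 0, 1, 0, 1]
y_pred = [1, 0, 0, 0, 1]
print(error_rate_by_group(genders, y_true, y_pred))
```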
Bias Towards Age Groups
The continuous ages are bucketed using the edges [0, 10, 20, 30, 40, 50, 60, 70, 80, 90].
The model is evaluated on each bucket separately to measure its performance per age group.
Lung-related illness is expected to be more common in certain age groups (the mid-adult range).
The test error rate is high for the 0-20 age group, which is acceptable.
The 70-90 age groups likely have fewer data samples.
These observations could be quantified more rigorously by incorporating domain knowledge.
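The age bucketing and per-bucket evaluation can be sketched with the standard library's `bisect`. The edges come from the report; treating buckets as left-closed intervals (`[20, 30)`, labeled `20-30`) and putting ages of 90 and above into an open-ended `90+` bucket are assumptions, since the report does not say how boundaries are handled. The toy inputs are hypothetical, and the per-bucket sample count helps flag sparsely populated groups such as the oldest ones.

```python
import bisect
from collections import Counter

AGE_EDGES = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]  # bucket edges from the report

def age_bucket(age: float) -> str:
    """Map an age to a left-closed bucket label such as '20-30'; ages >= 90 go to '90+'."""
    if age < AGE_EDGES[0]:
        raise ValueError(f"negative age: {age}")
    i = bisect.bisect_right(AGE_EDGES, age) - 1
    if i == len(AGE_EDGES) - 1:
        return f"{AGE_EDGES[-1]}+"
    return f"{AGE_EDGES[i]}-{AGE_EDGES[i + 1]}"

def error_rate_by_age_bucket(ages, labels, predictions):
    """Return {bucket: (error rate, sample count)} for hard 0/1 predictions."""
    wrong, total = Counter(), Counter()
    for age, y, p in zip(ages, labels, predictions):
        bucket = age_bucket(age)
        total[bucket] += 1
        wrong[bucket] += int(p != y)
    return {b: (wrong[b] / total[b], total[b]) for b in total}

# Hypothetical toy values, not results from the report.
print(error_rate_by_age_bucket([5, 15, 45, 47, 72], [1, 0, 1, 1, 0], [0, 1, 1, 0, 0]))
```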
```python
import wandb

# Initialize a W&B run.
run = wandb.init()

# Download model_nih_1.h5 as an artifact.
artifact = run.use_artifact('wandb/model-card-NIH-Chest-X-ray-binary/model:latest')
artifact_dir = artifact.download()

# Close the run.
run.finish()
```