X-Ray Illumination
Introduction
Deep learning holds great promise for the medical field. A computer model can assist doctors with references or prioritization as they diagnose patients. When professional expertise is scarce, the model can make a best guess based on the statistics of all the patient history it has seen. In a concrete example from 2017, Pranav Rajpurkar et al. from the Stanford ML Group developed CheXNet, a deep convolutional neural network which, in certain controlled contexts, is claimed to diagnose pneumonia from chest x-rays more accurately than an average radiologist.
This impressive result relies on the NIH Chest X-ray Dataset of 112,120 x-rays from 30,805 patients, introduced by Wang et al. at CVPR 2017 as ChestX-ray8 and later expanded to ChestX-ray14. The figure above is from this paper and illustrates some of the difficulties of framing this as a computer vision problem:
- there are many lung pathologies besides pneumonia, which co-occur at different rates. The dataset covers 15 distinct labels (14 pathologies plus "No Finding"), and multiple pathologies are more common than single pathologies. For example, "Infiltration", "Effusion", "Atelectasis", and "Mass" are marked together most often, but all of these labels can also appear alone or, occasionally, alongside others like "Cardiomegaly" and "Pneumothorax" (a counting sketch follows this list)
- diagnoses normally require assessing the rest of the patient (medical history, symptoms, etc.) and are not determined by a single x-ray image
- image alignment, quality (contrast, focus), and annotations (text or symbols added by the imaging setup) vary substantially in the dataset
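To get a feel for the co-occurrence structure, here is a rough counting sketch. It assumes the standard Data_Entry_2017.csv metadata file that ships with the NIH dataset, where multi-label findings are pipe-separated in the "Finding Labels" column:

```python
import itertools
from collections import Counter

import pandas as pd

# Load the NIH metadata file; "Finding Labels" holds pipe-separated
# pathology names, e.g. "Infiltration|Effusion".
df = pd.read_csv("Data_Entry_2017.csv")
labels = df["Finding Labels"].str.split("|")

# Count how often each label appears alone vs. paired with another.
single = Counter(l[0] for l in labels if len(l) == 1)
pairs = Counter(
    pair for l in labels if len(l) > 1
    for pair in itertools.combinations(sorted(l), 2)
)

print("Most common single labels:", single.most_common(5))
print("Most common co-occurring pairs:", pairs.most_common(5))
```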
Training on small and noisy ground truth
The ground truth labels are text-mined from the associated radiological reports, which is an inherently approximate process. Some radiologists have challenged these labels as significantly, and systematically, worse than the >90% accuracy claimed for this weakly-supervised approach. Still, this is one of the largest open medical imaging datasets available. Here, I explore the dataset and existing classification approaches, try simple baselines, and outline some strategies for training models in more realistic healthcare data settings, which I hope to explore in future reports:
- using less data, namely a random 5% sample of the full dataset
- training on patient metadata beyond the image itself and learning representations invariant to this metadata
- addressing class imbalance, i.e. the long tail of rare conditions with few examples
The Dataset: Random 5% subsample
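Since each patient can have many x-rays, a purely random image-level sample risks leaking the same patient across train and validation splits. A minimal sketch of drawing the 5% subsample at the patient level instead (again assuming the standard Data_Entry_2017.csv metadata file):

```python
import pandas as pd

df = pd.read_csv("Data_Entry_2017.csv")

# Sample 5% of *patients* rather than images, so no patient's
# x-rays end up split across train and validation sets.
patients = df["Patient ID"].drop_duplicates()
sample_ids = patients.sample(frac=0.05, random_state=42)
subset = df[df["Patient ID"].isin(sample_ids)]

print(f"{len(subset)} images from {sample_ids.nunique()} patients")
```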
Simple CNN Baseline
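For reference, a baseline along these lines might look like the following small Keras CNN. The input resolution, layer sizes, and single-label softmax head are illustrative assumptions, not the exact model trained here:

```python
from tensorflow import keras
from tensorflow.keras import layers

NUM_CLASSES = 15  # 14 pathologies + "No Finding" (assumption)

# A small convolutional baseline over downsampled grayscale x-rays.
model = keras.Sequential([
    layers.Input(shape=(128, 128, 1)),
    layers.Conv2D(32, 3, activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(128, 3, activation="relu"),
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation="relu"),
    layers.Dense(NUM_CLASSES, activation="softmax"),  # naive single-label setup
])
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
model.summary()
```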
Next steps
Training on the natural, highly skewed distribution saturates quickly: naive convolutional models trained on the unbalanced data simply learn to predict the most frequent class. How can we account for the long tail and improve accuracy with a small model, without adding more data?
Small, long-tailed data; basic model
- balanced training: create a balanced split of the dataset (see the resampling sketch after this list)
- image cropping to exclude the annotations/edges
- feed in extra patient metadata like age, gender, etc. through an embedding (sketched below, together with the multi-label head)
- multi-label classification scenario: binary cross-entropy over per-class sigmoid outputs, treating each class as one positive label versus the rest as negatives (sketched below)
- Facebook's classifier-balancing approach ("Decoupling Representation and Classifier for Long-Tailed Recognition"): repo and paper
- use MAML-style meta-learning with something like 100-shot learning
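As a sketch of the balanced-training idea, here is one way to resample a long-tailed label column to a uniform class distribution. It assumes a dataframe like the NIH metadata and, for simplicity, treats each distinct label string as its own class:

```python
import pandas as pd

def balanced_split(df, label_col="Finding Labels", per_class=200, seed=42):
    """Resample each class to the same size: upsample rare classes
    (with replacement) and downsample common ones."""
    groups = []
    for _, group in df.groupby(label_col):
        groups.append(group.sample(n=per_class,
                                   replace=len(group) < per_class,
                                   random_state=seed))
    # Concatenate and shuffle so batches mix classes.
    return pd.concat(groups).sample(frac=1, random_state=seed)
```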
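And a hedged sketch combining the metadata and multi-label ideas: a two-branch Keras model that embeds gender, concatenates it with normalized age, and trains per-class sigmoid outputs with binary cross-entropy, so each pathology becomes an independent positive-versus-negative decision. All layer sizes here are illustrative assumptions:

```python
from tensorflow import keras
from tensorflow.keras import layers

NUM_CLASSES = 15  # assumption, as above

# Image branch.
image_in = layers.Input(shape=(128, 128, 1), name="image")
x = layers.Conv2D(32, 3, activation="relu")(image_in)
x = layers.MaxPooling2D()(x)
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)

# Metadata branch: normalized age plus an embedded gender id (0/1).
age_in = layers.Input(shape=(1,), name="age")
gender_in = layers.Input(shape=(1,), name="gender")
g = layers.Embedding(input_dim=2, output_dim=4)(gender_in)
g = layers.Flatten()(g)
meta = layers.Concatenate()([age_in, g])
meta = layers.Dense(16, activation="relu")(meta)

# Multi-label head: sigmoid + binary cross-entropy treats each
# pathology as an independent yes/no decision.
merged = layers.Concatenate()([x, meta])
out = layers.Dense(NUM_CLASSES, activation="sigmoid")(merged)

model = keras.Model([image_in, age_in, gender_in], out)
model.compile(optimizer="adam", loss="binary_crossentropy")
```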
Increasing data size; more complex model
- increase training data incrementally to trade off bias
- replicate findings from CheXNet: finetune a DenseNet-121 pretrained on ImageNet, optionally retraining from scratch (a finetuning sketch follows this list). This notebook defines a DenseNet in Keras.
- compare to the state of the art and repositories released after CheXNet
- explore CAM visualizations of diagnosis-relevant areas (a Grad-CAM sketch follows this list)
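A minimal Keras sketch of the CheXNet-style setup: DenseNet-121 pretrained on ImageNet with the classifier replaced by a 14-way sigmoid head. The paper's exact training schedule and augmentation are not reproduced here:

```python
from tensorflow import keras
from tensorflow.keras import layers

# DenseNet-121 backbone pretrained on ImageNet, global-average-pooled.
base = keras.applications.DenseNet121(
    include_top=False, weights="imagenet",
    input_shape=(224, 224, 3), pooling="avg")

# 14-way sigmoid head, one output per pathology.
out = layers.Dense(14, activation="sigmoid")(base.output)
model = keras.Model(base.input, out)

# Optionally freeze the backbone for a warm-up phase before
# finetuning everything end to end.
base.trainable = False
model.compile(optimizer=keras.optimizers.Adam(1e-3),
              loss="binary_crossentropy")
```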
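For the visualizations, Grad-CAM is a common generalization of CAM. A rough sketch follows; `layer_name` is assumed to point at the backbone's last convolutional feature map, which depends on the architecture:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

def grad_cam(model, image, layer_name, class_index):
    """Grad-CAM: weight the chosen feature maps by the gradient of the
    class score, then sum them into a coarse spatial heatmap."""
    # Expose both the feature map and the final predictions.
    grad_model = keras.Model(
        model.input,
        [model.get_layer(layer_name).output, model.output])
    with tf.GradientTape() as tape:
        conv_out, preds = grad_model(image[np.newaxis])
        score = preds[:, class_index]
    # Average gradients over space for one weight per channel,
    # then take the weighted sum of the feature maps.
    grads = tape.gradient(score, conv_out)
    weights = tf.reduce_mean(grads, axis=(1, 2))
    cam = tf.einsum("bijc,bc->bij", conv_out, weights)[0]
    cam = tf.nn.relu(cam) / (tf.reduce_max(cam) + 1e-8)  # scale to [0, 1]
    return cam.numpy()
```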