Deep learning holds great promise for the medical field. A computer model can assist doctors with references or prioritization as they diagnose patients. When professional expertise is scarce, the model can make a best guess based on the statistics of all the patient histories it has seen. In a concrete example from 2017, Pranav Rajpurkar et al. from the Stanford ML Group developed CheXNet, a deep convolutional neural network that, according to its authors, can detect pneumonia from chest x-rays more accurately than an average radiologist in certain controlled contexts.
This impressive result relies on the NIH Chest X-ray Dataset of 112,120 x-rays from 30,805 patients, an expanded version (ChestX-ray14) of the ChestX-ray8 dataset introduced by Wang et al. at CVPR 2017. The figure above is from this paper and illustrates some of the difficulties of framing this as a computer vision problem:
The ground-truth labels are text-mined from the associated radiology reports, which is an inherently approximate process. Some radiologists have argued that these labels are significantly, and systematically, less accurate than the >90% accuracy claimed for this weakly supervised approach. Still, this is one of the largest open medical imaging datasets available. Here, I explore the dataset and existing classification approaches, try simple baselines, and outline some strategies for training models in more realistic healthcare data settings, which I hope to explore in future reports.
I start with a more manageable subset of the full data: a random 5% sample, or 5,606 images, available from Kaggle. Most of the conditions are very hard to detect with an untrained eye, but it's helpful to look at more of the images.
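Kaggle already provides this 5% sample, so no code is needed to reproduce it. For readers who want a comparable subset of the full dataset, here is a minimal sketch of drawing a reproducible random 5% subset. The filenames and the `sample_subset` helper are hypothetical. Note that 5% of 112,120 is exactly 5,606.

```python
import random


def sample_subset(records, fraction=0.05, seed=0):
    """Draw a reproducible random subset of `records` without replacement."""
    k = round(len(records) * fraction)
    rng = random.Random(seed)  # private RNG, so the global random state is untouched
    return rng.sample(records, k)


# Hypothetical stand-in for the full dataset's image filenames.
all_images = [f"img_{i:06d}.png" for i in range(112_120)]
subset = sample_subset(all_images)
print(len(subset))  # 5606
```

This samples individual images, so images from the same patient can land in both training and validation splits; splitting by patient ID avoids that leakage.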
More than half of the data is marked "No Finding", though the dataset notes that this label is noisy: it includes conditions that went undetected and patients who were diagnosed only on a later visit. "Infiltration" is the second-largest class, and the tail of rarer conditions, including "Pneumonia", is long. A binary framing may help compensate for this skew: learn a separate binary classifier for each label X, and treat every example without label X as a negative for that classifier. Here is a visualization of the skew from an excellent replication attempt in PyTorch by evakli11 and huntforgz.
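The label counting and the binary (one-vs-rest) framing above can be sketched in a few lines. This assumes each image's labels are stored as a single pipe-separated string such as `"Effusion|Infiltration"`, as in the NIH metadata CSV; the toy rows and helper names are hypothetical. Because an image can carry several labels, counts can sum to more than the number of images.

```python
from collections import Counter

# Hypothetical toy rows: one pipe-separated label string per image.
rows = [
    "No Finding",
    "Infiltration",
    "Effusion|Infiltration",
    "No Finding",
    "Pneumonia",
    "No Finding",
]


def parse_labels(finding_labels):
    """Split a pipe-separated label string into a set of label names."""
    return set(finding_labels.split("|"))


def label_frequencies(rows):
    """Count how many images carry each label (an image may carry several)."""
    counts = Counter()
    for row in rows:
        counts.update(parse_labels(row))
    return counts


def one_vs_rest_targets(rows, label):
    """Binary targets for one classifier: 1 if the image has `label`, else 0.

    Every image without `label` is a negative, whatever else it shows.
    """
    return [int(label in parse_labels(row)) for row in rows]


counts = label_frequencies(rows)
print(counts["No Finding"], counts["Pneumonia"])  # 3 1
targets = {lab: one_vs_rest_targets(rows, lab) for lab in sorted(counts)}
print(targets["Pneumonia"])  # [0, 0, 0, 0, 1, 0]
```

Each per-label target list then trains its own binary classifier (or one output unit with a sigmoid in a shared network), so a rare label like "Pneumonia" is never forced to compete directly against "No Finding" in a single softmax.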
Initial experiments follow a popular Kaggle notebook by Paul Mooney, training a simple CNN either on all labels or on the pathologies alone. In this proof of concept, the model learns little and converges to a validation accuracy roughly equal to the share of the most frequent class in the data. The confusion matrix confirms this: the model simply guesses the most frequent class every time, and no modifications to the optimizer, batch norm, or filter size make a difference at this stage. How much can we improve on this basic model without adding complexity or data?
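A validation accuracy that merely matches the majority-class share is easiest to recognize by comparing against that baseline directly. Here is a minimal sketch, using hypothetical toy single-label data, that computes the accuracy of always predicting the most frequent training class and the confusion matrix such a degenerate model produces. All of the mass falls into one predicted column.

```python
from collections import Counter


def majority_baseline(train_labels, val_labels):
    """Accuracy on val_labels of always predicting the most frequent training class."""
    majority = Counter(train_labels).most_common(1)[0][0]
    correct = sum(1 for y in val_labels if y == majority)
    return majority, correct / len(val_labels)


def confusion_matrix(y_true, y_pred, classes):
    """Rows are true classes, columns are predicted classes."""
    index = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


# Hypothetical toy labels, skewed toward "No Finding".
train = ["No Finding"] * 6 + ["Infiltration"] * 3 + ["Pneumonia"] * 1
val = ["No Finding"] * 5 + ["Infiltration"] * 3 + ["Pneumonia"] * 2

majority, acc = majority_baseline(train, val)
print(majority, acc)  # No Finding 0.5

classes = ["No Finding", "Infiltration", "Pneumonia"]
preds = [majority] * len(val)  # what the collapsed CNN effectively outputs
for row in confusion_matrix(val, preds, classes):
    print(row)
# [5, 0, 0]
# [3, 0, 0]
# [2, 0, 0]
```

Any model worth keeping should beat this baseline's accuracy and, more importantly, place some predictions outside the majority column.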
Training on the natural, highly skewed distribution saturates quickly: naive convolutional models trained on the unbalanced data simply predict the most frequent class. How can we account for the long tail and improve accuracy on a small model without adding more data?