COVID-19 Research Project using PyTorch and Weights & Biases

Ayush Chaurasia, Contributor

COVID-19, or coronavirus, has taken the world by storm. At the time of writing, the WHO has already declared it a pandemic. Some of the world’s best research institutes are racing to develop vaccines to curb its spread, and deep learning researchers are hard at work on systems that can assist in detecting infected patients.

In this tutorial, I’ll provide a boilerplate for anyone who’d like to do research on COVID-19 datasets. We’ll walk through readying the dataset, setting up early-phase experiments, and homing in on the best-performing model through hyperparameter optimization.

View COVID-19 Scans in live dashboards→

The Dataset

The dataset was compiled by Adrian Rosebrock of pyimagesearch and consists of 25 chest X-rays of COVID-19 patients and 25 chest X-rays of healthy patients. This is a “deep learning in radiology” problem with a toy dataset. We’ll use PyTorch Lightning, a high-level wrapper around the PyTorch library. You can learn more about PyTorch Lightning and how to use it with Weights & Biases here.

Let’s load the dataset using PyTorch Lightning:

import torch
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

class Classifier(pl.LightningModule):
    def train_dataloader(self):
        # Resize to 224x224 (the input size ResNet/VGG expect) and convert to tensors
        transform_data = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.ToTensor()])
        data = torchvision.datasets.ImageFolder('./dataset', transform=transform_data)
        # 80-20 train/validation split
        train_size = int(0.8 * len(data))
        test_size = len(data) - train_size
        self.train_dataset, self.test_dataset = random_split(data, (train_size, test_size))
        return DataLoader(self.train_dataset, batch_size=16)

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=16)

Here, we have overridden the train_dataloader() and val_dataloader() hooks defined by PyTorch Lightning; the Trainer uses them to load the training and validation sets. We have split the dataset 80-20, so 80% of the data is used for training and 20% for validation.
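To make the arithmetic of the split concrete, here is a quick sanity check (assuming the 50-image dataset described above):

```python
# 80-20 split arithmetic for a 50-image dataset (25 COVID-19 + 25 healthy)
n_images = 50
train_size = int(0.8 * n_images)   # 40 images for training
test_size = n_images - train_size  # 10 images for validation
print(train_size, test_size)       # 40 10
```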

Let’s have a look at a sample from the dataset.

The X-ray on the left is of a healthy person and the one on the right is of a COVID-19 patient.

Setting Up Our Experiment with Sweeps

Our intention here is to try different models and optimize their hyperparameters to find the best model for our use case. Instead of manually trying out different hyperparameters, we can set up a Weights & Biases sweep to automate the process.

First, we need to specify the parameters we’re going to sweep over, along with their possible values. Let’s define these in a dictionary. We’ll also define the default values of these hyperparameters.

sweep_config = {
    'method': 'random',  # grid, random
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    'parameters': {
        'learning_rate': {
            'values': [0.1, 0.01, 0.001]
        },
        'optimizer': {
            'values': ['adam', 'sgd']
        },
        'model': {
            'values': ['resnet18', 'VGG16']
        }
    }
}

config_defaults = {
    'learning_rate': 0.001,
    'optimizer': 'adam',
    'model': 'resnet18'
}

Here, I’ve chosen the VGG16 and ResNet-18 models because our dataset is small.

Building the Classifier Class

Here we load the predefined models from torchvision. We need to change the last layer of both networks to output only 2 neurons, since this is a binary classification problem.

    def __init__(self):
        super(Classifier, self).__init__()
        # Pick the architecture from the sweep config and replace its final layer
        if wandb.config.model == 'resnet18':
            self.model = torchvision.models.resnet18()
            self.model.fc = torch.nn.Linear(512, 2)
        elif wandb.config.model == 'VGG16':
            self.model = torchvision.models.vgg16()
            self.model.classifier[6] = torch.nn.Linear(4096, 2)

The forward function is straightforward: it simply calls the model we created above. The other function computes the cross-entropy loss.

def forward(self, x):
   x = self.model(x)      
   return x

def cross_entropy_loss(self, logits, labels):
   return F.cross_entropy(logits, labels)
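Under the hood, F.cross_entropy applies a softmax to the raw logits and then takes the negative log-probability of the true class. Here is a plain-Python sketch of that computation for a single example (PyTorch does this on batched tensors, with better numerical stability):

```python
import math

def cross_entropy(logits, label):
    # Softmax over the logits, then negative log-likelihood of the true class
    exps = [math.exp(z) for z in logits]
    probs = [e / sum(exps) for e in exps]
    return -math.log(probs[label])

# A confident, correct prediction gives a small loss; a coin-flip gives ln(2)
print(cross_entropy([4.0, 0.0], 0))  # ~0.018
print(cross_entropy([0.0, 0.0], 0))  # ~0.693
```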

I have omitted the calculation part from the code below, as we’re going to focus on the logging; the entire code is available in this GitHub repo. Here, we log the training loss as well as the validation loss directly to the Weights & Biases dashboard.

def training_step(self, train_batch, batch_idx):
    # ... perform the training pass (omitted) ...
    logs = {'train_loss': loss}
    return {'loss': loss, 'log': logs}

def validation_step(self, val_batch, batch_idx):
    # ... perform the validation pass (omitted) ...
    return {'val_loss': loss}

def validation_end(self, outputs):
    # ... average the per-batch validation losses (omitted) ...
    logs = {'val_loss': avg_loss}
    return {'avg_val_loss': avg_loss, 'log': logs}
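validation_end receives the list of dicts produced by validation_step, one per batch, and reduces them to a single number. Conceptually it looks like this (in the repo the averaging is done on tensors; the values here are purely illustrative):

```python
# Each validation_step returns {'val_loss': ...}; validation_end averages them
outputs = [{'val_loss': 0.72}, {'val_loss': 0.55}, {'val_loss': 0.61}]
avg_loss = sum(o['val_loss'] for o in outputs) / len(outputs)
print(avg_loss)  # ~0.627
```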

Finally, we’ll define the optimizer function to return an optimizer of our choice.

    def configure_optimizers(self):
        # Default to Adam; switch to SGD if the sweep config says so
        optimizer = torch.optim.Adam(self.parameters(), lr=wandb.config.learning_rate)
        if wandb.config.optimizer == 'sgd':
            optimizer = torch.optim.SGD(self.parameters(), lr=wandb.config.learning_rate)
        return optimizer

Hyperparameter Optimization with Sweeps

Now we’re ready to sweep through all the possible combinations of models and their hyperparameters.

def train():
    wandb.init(config=config_defaults)  # start a W&B run with the default hyperparameters
    model = Classifier()
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model)

Here we’ve created an instance of the classifier class and set up the trainer, which loads the data through the dataloader hooks we defined earlier.


Now let’s run the sweep:
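A W&B sweep is launched in two steps: register the configuration with wandb.sweep, then hand the train function to an agent, which repeatedly calls it with hyperparameter values drawn from the sweep. This sketch assumes the sweep_config and train defined earlier; the project name and the run_sweep wrapper are illustrative:

```python
def run_sweep():
    import wandb  # assumes the wandb client is installed and you are logged in
    # Register the sweep configuration with the W&B backend (project name is illustrative)
    sweep_id = wandb.sweep(sweep_config, project="covid19-classification")
    # The agent calls train() once per run, with hyperparameters sampled from the sweep
    wandb.agent(sweep_id, function=train)
```

Because the method is 'random', the agent keeps sampling indefinitely; pass count to wandb.agent to cap the number of runs.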


Here is the Sweep page:

Now we have the boilerplate code for COVID-19 research. Although the dataset we’ve used here isn’t nearly enough for building production-ready systems, this code structure can nevertheless be used with any dataset released in the future by hackathon organizers or research institutes. So go ahead, assemble your own dataset, and try to build a COVID-19 detector using Weights & Biases.

Visualizing Sweeps Performance

Let’s have a look at the training and validation accuracy for all the runs in the sweep:

The visualizations generated by Weights & Biases are highly customizable: you can set the range of values you'd like to plot, add a smoothing factor, create custom expressions, group runs, add custom legends and colors, filter out runs by custom criteria, and pick the number of runs to plot, which is set to 10 by default. I’ve included all 12 runs. I’ve also lowered the upper limit of the Y-axis so the chart isn’t distorted by the unusually high loss values that occur when gradient descent overshoots.
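The smoothing factor mentioned above applies an exponential moving average to the plotted series. Here is a small sketch of the idea (the weight value is illustrative, and W&B’s actual implementation may differ in details):

```python
def smooth(values, weight=0.5):
    # Exponential moving average: each point blends the previous smoothed
    # value with the new observation
    smoothed, last = [], values[0]
    for v in values:
        last = weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed

print(smooth([1.0, 0.0, 0.0, 0.0]))  # [1.0, 0.5, 0.25, 0.125]
```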

Sweeps Visualizations

There are other useful visualizations produced by sweeping through the runs.

Parallel Coordinates Plot

The parallel coordinates plot shows the correlation between the various hyperparameters and the metric we’re trying to optimize (val_loss). Here we can see that high loss values are correlated with VGG and with higher learning rates.

Hyperparameter Importance Plot

This plot surfaces which hyperparameters were the best predictors of, and most highly correlated with, desirable values of your metric (in this case val_loss).

If you also want to determine the statistical importance of the hyperparameters, use the ‘parameter importance’ chart, which is also generated automatically. Here we can see that ResNet is negatively correlated with high loss values, whereas higher learning rates are positively correlated, so we should pick ResNet and a lower learning rate to minimize loss. You can get both these plots by clicking Add Visualization on your project page.

Weights & Biases Reports

Weights & Biases reports are a fast, efficient way to share your insights and analysis, including all of the visualizations generated in the dashboard. You control the level of granularity of your report: pick which visualizations make sense in context, and add sections to dive deeper into the code if needed. Then you can easily share the report in LaTeX format!

Here’s a quick report I generated using Weights & Biases; it took me no more than two minutes to finish, including adding the descriptions. Here’s the link - COVIDExampleReport.

I’m pretty sure many more COVID-19 datasets will be made public, and various competitions have already announced release dates for COVID-19 datasets. So go ahead: tweak the code, test it with the dataset of your choice, and write reports to share your models and the insights you’ve gained with the machine learning community!
