
Getting Started with the Microsoft Rice Disease Classification Challenge

Building a model to classify rice disease using FastAI
Zindi has just launched the Microsoft Rice Disease Classification Challenge in which participants are tasked with creating a model that can identify different diseases in images from Egyptian rice fields. Early detection and identification of these diseases can help improve yields for farmers, so this is an important task!
In this post, we'll get started with an entry into this competition and show how you can build on this baseline and begin climbing the leaderboard. To follow along, head over to the competition page to register and download the data and then launch this notebook which contains everything you'll need to try the code for yourself and submit an entry.

Preparing the Data

The dataset consists of images (both regular RGB images and Red-Green-Near-Infrared (RGNIR) images) along with some CSV files containing image ids and labels for the training set. For this first attempt, we'll focus on the RGB images only. We can load the labels into a pandas dataframe and then prepare our Fastai dataloaders like so:
Loading the images using Fastai's ImageDataLoaders.from_df function.
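Roughly, that looks like the following sketch. The CSV name, column names ('Image_id', 'Label') and image folder here are assumptions based on the competition data, so adjust them to match your download:

```python
import pandas as pd
from fastai.vision.all import *

# Read the training labels (assumed columns: 'Image_id' and 'Label')
train_df = pd.read_csv('Train.csv')

# Build dataloaders straight from the dataframe, holding out 20% for
# validation, resizing each image and applying fastai's default batch
# augmentations (flips, warps, zooms, brightness/contrast tweaks)
dls = ImageDataLoaders.from_df(
    train_df,
    path='images',              # folder containing the RGB images (assumed)
    fn_col='Image_id',
    label_col='Label',
    valid_pct=0.2,
    seed=42,
    item_tfms=Resize(224),
    batch_tfms=aug_transforms(),
)
dls.show_batch()                # quick visual check of a batch
```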
In addition to simply loading in the images, this little snippet also applies some data augmentation to our data. This includes transforms like random horizontal flips, brightness and contrast tweaks, warps, zooms and so on, which help to reduce the chances of overfitting by providing the model with more diverse inputs during training. This is especially useful when we have a relatively small dataset such as this one.

Training a model

Our goal here is to get up and running quickly, so we'll start with a pre-trained ResNet model and leverage Fastai's fine_tune() function to train it using some sensible defaults:
Fine-tuning a pretrained model
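With the dataloaders from above in hand, this boils down to something like the following (vision_learner is the current name for the learner constructor; older fastai versions call it cnn_learner):

```python
# Pretrained resnet34 backbone with a new head sized for our classes
learn = vision_learner(dls, resnet34, metrics=accuracy)
# One frozen epoch for the head, then four epochs with everything unfrozen
learn.fine_tune(4)
```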
Behind the scenes, these two lines of code do quite a bit of work for us:
  • Download the weights of a pre-trained model, in this case, one with a resnet34 architecture
  • Replace the final few layers of this model with a new classification head, which has the right number of outputs for our task (deduced from the dataloaders we created)
  • Train this new head for one epoch while keeping the rest of the network frozen
  • Unfreeze the model and train it for another four epochs at a lower learning rate
As you can see, within a few minutes we've reached nearly 90% accuracy on this task!

Sanity Checks

It is always important to do a few 'sanity checks' when training a model like this, to make sure we haven't messed up something in our pipeline. For example, it can be useful to view example predictions with learn.show_results() or, even better, to see some examples where the model makes incorrect predictions:
The plot_top_losses function in action. This is an excellent way to check for labelling issues etc.
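Continuing from the training snippet above, both checks are only a few lines:

```python
# Show a few predictions alongside the ground-truth labels
learn.show_results()

# Inspect the validation examples the model got most wrong
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_top_losses(9, figsize=(12, 10))
```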
Another useful tool here is the confusion matrix, which shows how many samples from each class are assigned each label by the model. This can highlight pairs of classes that the classifier finds particularly difficult to separate, or reveal issues with class balance, such as a particular class being under-represented in the data.
A confusion matrix - most samples are correctly classified
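The same interpretation object from the previous snippet can produce it:

```python
# Rows are the actual classes, columns are the model's predictions
interp.plot_confusion_matrix(figsize=(6, 6))
```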
In this case, we see that the model predicts the right label in most cases, and the top losses show some particularly hard examples where the misclassifications seem reasonable. There is one image that looks over-exposed, hinting that we might want to tone down the contrast/brightness-related augmentation in case this is a side effect of the aug_transforms settings used.

Making a Submission

Once you're happy with the model, it's time to make predictions on the test set. Zindi provides the image ids in the same format as the training data, but with the label column removed. This means we can create a test dataloader with the same processing steps as our training dataloader with dls.test_dl(test) and get the predictions with learn.get_preds. The notebook shows the full process for preparing the submission file in the format Zindi requires.
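Sketched out, it looks like this (the file and column names are again assumptions based on the competition data; check Zindi's sample submission for the exact format expected):

```python
# Build a test dataloader that reuses the training pipeline's preprocessing
test_df = pd.read_csv('Test.csv')        # assumed filename and columns
test_dl = learn.dls.test_dl(test_df)

# Predicted class probabilities for the test set
preds, _ = learn.get_preds(dl=test_dl)

# One probability column per class, plus the image id column
sub = pd.DataFrame(preds.numpy(), columns=list(learn.dls.vocab))
sub.insert(0, 'Image_id', test_df['Image_id'].values)
sub.to_csv('submission.csv', index=False)
```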
Don't hit 'submit' just yet though - Fastai has one final trick up its sleeve: 'Test Time Augmentation'. We can swap in tta(...) in place of get_preds(...) and Fastai will use the same data augmentation strategies we used during training to create multiple versions of each image in the test set and average the predictions across these different variants. It's slightly slower but tends to give a final boost in performance.
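The swap is a one-liner (continuing from the snippet above):

```python
# Average predictions over several augmented versions of each test image
preds, _ = learn.tta(dl=test_dl)
```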

Next Steps

Our entry here scores ~0.35: 16th place at the time of writing, and about equal to the Zindi baseline. Not bad for ~2 minutes of training! There are various ways you could improve on this, such as:
  • Trying a larger model
  • Training for longer
  • Optimizing hyperparameters such as learning rate
  • Exploring different types of data augmentation
  • Experimenting with including the multispectral images
  • Ensembling multiple models together
  • ...
It's tempting to dive right in, but it's also very easy to get lost playing with all of these different tweaks. So, in the final section, we'll look at how to keep track of your experiments and quickly dial in a winning strategy.

Experiment Tracking

Good news: experiment tracking doesn't have to be hard! In the final section of the notebook, I followed this great guide to add Weights & Biases tracking. Most of the work is done by passing in cbs=WandbCallback() when creating the learner, which automatically tracks training hyperparameters (such as the learning rate), training stats (loss, epochs and so on) and validation metrics (validation loss, error rate and so on).
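As a rough sketch of the setup (the project name and config values below are placeholders, not the exact settings from the notebook):

```python
import wandb
from fastai.callback.wandb import WandbCallback

# Start a W&B run, recording whatever settings we want to compare later
wandb.init(project='rice-disease-demo',   # placeholder project name
           config={'arch': 'resnet34', 'img_size': 224, 'batch_size': 64})

# The callback logs losses, metrics and hyperparameters automatically
learn = vision_learner(dls, resnet34, metrics=accuracy, cbs=WandbCallback())
learn.fine_tune(4)

wandb.finish()   # close out the run once the experiment is done
```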


For each experiment ('run') you call wandb.init(), passing in any configuration or parameters you want to keep track of. No more guesswork trying to remember which batch size worked best! And speaking of batch size, W&B automatically logs system metrics as well. My first few runs were using <20% of the GPU memory, a clear hint that the batch size or image size could be comfortably increased. The system stats also flagged another issue: when using the RGNIR images the CPU was pegged at 100% while GPU utilization was under 5%, which turned out to be due to the larger size of the multispectral images slowing down the dataloader.


For this demo, I tried switching to the RGNIR images (worse, and slower), training for longer using progressive resizing (pretty good) and switching to a larger densenet201 model (very good). The result is a model with >96% accuracy that scores ~0.14 on the leaderboard (9th place at the time of writing). You can see from the confusion matrix above that with this model there are only a handful of misclassified images in each category.
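If you haven't tried progressive resizing before, here is one way to set it up in fastai; the image sizes and epoch counts are placeholders rather than the exact settings used for the run above:

```python
# Train at a small image size first, then swap in dataloaders built at a
# larger size and keep training the same model
def make_dls(size):
    return ImageDataLoaders.from_df(
        train_df, path='images', fn_col='Image_id', label_col='Label',
        item_tfms=Resize(size), batch_tfms=aug_transforms(), seed=42)

learn = vision_learner(make_dls(128), densenet201, metrics=accuracy)
learn.fine_tune(4)               # initial training on small images

learn.dls = make_dls(224)        # switch to larger images
learn.fine_tune(4)               # continue training at higher resolution
```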

Conclusions

There are many more places where some additional performance could be squeezed out, but they are left as an exercise for the reader ;) I hope this post has given you the motivation you need to dive in and get started, and the tools required to track your improvements and climb up the rankings. If you have any questions or comments, feel free to reach out to me @johnowhitaker on Zindi, Twitter, Gmail etc. See you on the leaderboard!