Nature Photo Image Classification with Keras + Weights & Biases

Versioning and interactively exploring image data and class predictions with Weights & Biases


Overview

This is a walkthrough of Tables for interactive visualization and Artifacts for versioning datasets and deep learning models in Weights & Biases. As an example, I fine-tune a convolutional neural network in Keras on photos from iNaturalist 2017 to identify 10 classes of living things (plants, insects, birds, etc.).

Project workflow

  1. Upload raw data
  2. Create a balanced split (train/val/test)
  3. Train model and validate predictions
  4. Run inference & explore results
We’ll start by uploading our raw data, then split it into train, validation, and test sets before spending the bulk of our time on model training, validating predictions, running inference, and exploring our results. In Artifacts, we can see the connections between our datasets, predictions, and models. Our workflow looks like this:
Inputs and outputs, from Weights & Biases artifact DAG
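
Here is a minimal sketch of steps 1 and 2 with the Artifacts API. The project name, artifact names, and local directory paths are illustrative assumptions, not the exact ones from this report:

```python
import wandb

# Step 1: version the raw photos as an artifact
# ("nature-keras", "inat_raw", and the local directory are hypothetical names).
run = wandb.init(project="nature-keras", job_type="upload")
raw_data = wandb.Artifact("inat_raw", type="raw_data")
raw_data.add_dir("inaturalist_raw/")  # one subfolder per class, e.g. Fungi/
run.log_artifact(raw_data)
run.finish()

# Step 2: a second run consumes the raw artifact and logs a balanced split.
run = wandb.init(project="nature-keras", job_type="split")
raw_dir = run.use_artifact("inat_raw:latest").download()
# ... sample a balanced train/val/test split from raw_dir (omitted) ...
split = wandb.Artifact("inat_split", type="balanced_data")
for subset in ("train", "val", "test"):
    split.add_dir(subset, name=subset)
run.log_artifact(split)
run.finish()
```

Because each run declares which artifacts it uses and which it logs, W&B can assemble the DAG above automatically.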

Model training and validation

As a quick and simple example, we use a pre-trained Inception-V3 network and fine-tune it with a fully-connected layer of variable size.
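
A minimal sketch of that setup in Keras, assuming 299x299 RGB inputs and the 10 iNaturalist classes; `fc_size` is the variable-size tuning layer, and the optimizer settings are illustrative:

```python
import tensorflow as tf

def build_model(fc_size=1024, num_classes=10):
    # Load Inception-V3 pre-trained on ImageNet, without its classification head.
    base = tf.keras.applications.InceptionV3(
        weights="imagenet", include_top=False, input_shape=(299, 299, 3))
    base.trainable = False  # freeze the convolutional backbone for fine-tuning

    x = tf.keras.layers.GlobalAveragePooling2D()(base.output)
    x = tf.keras.layers.Dense(fc_size, activation="relu")(x)  # variable-size FC layer
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)

    model = tf.keras.Model(inputs=base.input, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss="categorical_crossentropy",
        metrics=["accuracy"])
    return model
```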

Model performance across runs



[Panel: model performance charts for the run set of 13 runs]


Checking model precision and confounding images

After every epoch of training, we log the model's predictions on the validation dataset to a Table. We experiment with different hyperparameters (the size of the tuning layer, the learning rate, and the number of epochs) and compare the validation predictions across model variants, as sketched below.
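
A minimal sketch of that per-epoch logging as a Keras callback. The column names and the in-memory `val_images`, `val_labels`, and `class_names` are assumptions for illustration; in a sweep, the hyperparameters themselves would come from `wandb.config`:

```python
import numpy as np
import tensorflow as tf
import wandb

class ValPredictionLogger(tf.keras.callbacks.Callback):
    """Log validation predictions to a W&B Table after each epoch."""

    def __init__(self, val_images, val_labels, class_names):
        super().__init__()
        self.val_images, self.val_labels = val_images, val_labels
        self.class_names = class_names

    def on_epoch_end(self, epoch, logs=None):
        probs = self.model.predict(self.val_images, verbose=0)
        guesses = np.argmax(probs, axis=1)
        table = wandb.Table(columns=["image", "guess", "truth", "score"])
        for img, guess, truth, p in zip(
                self.val_images, guesses, self.val_labels, probs):
            table.add_data(wandb.Image(img),
                           self.class_names[guess],
                           self.class_names[truth],
                           float(p[guess]))
        wandb.log({"val_predictions": table})
```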

[Panel: validation prediction tables for the run set of 13 runs]

When the model guesses a particular class, what is the distribution of true labels for those guesses? In this variant, we again see that "Mollusks" are a popular confound for "Animals" (second row, "truth" column). Interestingly, they're also the top confound for "Fungi". Scrolling through some of the images, perhaps snail shells on brown backgrounds or the bright colors of sea slugs against a dark sea are easily confused for mushrooms.
Grouping by guess helps us see potential systematic errors per class. For example, Mollusca is the class certain models most often incorrectly guess for Amphibia.
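
The same group-by-guess view can be computed offline. A minimal sketch with pandas, assuming `guesses` and `truths` are lists of class-name strings pulled from the table above:

```python
import pandas as pd

# Hypothetical predictions; in practice these come from the logged table.
df = pd.DataFrame({"guess": guesses, "truth": truths})

# For each guessed class, the distribution of true labels among those guesses.
per_guess = df.groupby("guess")["truth"].value_counts(normalize=True)
print(per_guess.loc["Fungi"])  # e.g. the share of "Fungi" guesses that are really Mollusca
```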

Run inference and explore results

Now let's switch to viewing predictions on the test data instead of the validation data.
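
A minimal sketch of the inference run, assuming the trained model was logged as a "model" artifact during training and that `test_images`, `test_labels`, and `class_names` are loaded as before; all names are illustrative:

```python
import numpy as np
import tensorflow as tf
import wandb

run = wandb.init(project="nature-keras", job_type="inference")

# Pull the trained model out of the artifact store and load it.
model_dir = run.use_artifact("model:latest").download()
model = tf.keras.models.load_model(model_dir)

# Predict on the held-out test split and log a table, mirroring validation.
probs = model.predict(test_images)
guesses = np.argmax(probs, axis=1)
table = wandb.Table(columns=["image", "guess", "truth"])
for img, g, t in zip(test_images, guesses, test_labels):
    table.add_data(wandb.Image(img), class_names[g], class_names[t])
run.log({"test_predictions": table})
run.finish()
```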

[Panel: test prediction tables for the run set of 13 runs]



System performance

With wandb logging system metrics every 2 seconds, the charts below help us monitor network traffic, disk utilization, and other valuable (and expensive) compute resources.

[Panel: system metrics charts for the run set of 13 runs]