Optimizing CIFAR-10 Hyperparameters with W&B and SageMaker

Chris Van Pelt

Everyone knows that hyperparameter sweeps are a great way to get an extra level of performance out of your algorithm, but we often don’t do them because they’re expensive and tricky to set up. AWS SageMaker makes it easy to do hyperparameter sweeps on your existing ML code, and W&B makes it effortless to see the results.

I had code for a CNN to classify images in the CIFAR-10 dataset, and I wanted to find the best set of hyperparameters.

Here’s a snippet from a working example where I used W&B with SageMaker.

estimator = PyTorch(entry_point="cifar10.py",
                    source_dir=os.getcwd() + "/source",
                    role=role,  # SageMaker execution role
                    instance_type="ml.p3.2xlarge",
                    instance_count=1,
                    # fixed hyperparameters shared by every run in the sweep
                    hyperparameters={
                        'epochs': 50,
                        'momentum': 0.9})

hyperparameter_ranges = {
    'lr': ContinuousParameter(0.0001, 0.001),
    'hidden_nodes': IntegerParameter(20, 100),
    'batch_size': CategoricalParameter([128, 256, 512]),
    'conv1_channels': CategoricalParameter([32, 64, 128]),
    'conv2_channels': CategoricalParameter([64, 128, 256, 512]),
}
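These ranges get handed to a SageMaker tuner alongside the estimator. Here's a minimal sketch of that wiring; it's a configuration fragment rather than a runnable script (launching it needs AWS credentials), and the objective metric name and regex are assumptions about what cifar10.py prints, not details from the original code.

```python
# Sketch only: assumes `estimator` and `hyperparameter_ranges` from above.
# The metric name and regex are hypothetical; they must match what the
# training script actually logs to stdout.
from sagemaker.tuner import HyperparameterTuner

tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="test_accuracy",
    metric_definitions=[{"Name": "test_accuracy",
                         "Regex": "test_accuracy: ([0-9\\.]+)"}],
    hyperparameter_ranges=hyperparameter_ranges,
    max_jobs=20,          # total training jobs in the sweep
    max_parallel_jobs=4,  # concurrent instances
    objective_type="Maximize")

tuner.fit()  # kicks off the tuning job
```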

SageMaker will spin up an AWS instance for each hyperparameter value and train the model. W&B tracks everything that happens and makes it easy to visualize the sweep. Here's a table in W&B where I'm tracking all the runs in the sweep, sorted by test accuracy. The test accuracy ranges from 10% (chance level for ten classes) to 76.45%, depending on the hyperparameters.
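Conceptually, each training job receives one sampled point from each range. Here's a plain-Python stand-in for that sampling step, with no AWS calls; the random draws below are illustrative and aren't SageMaker's actual search strategy.

```python
import random

# Plain-Python stand-in for the SageMaker parameter ranges above:
# each entry maps a hyperparameter name to a sampling function.
hyperparameter_ranges = {
    'lr': lambda: random.uniform(0.0001, 0.001),           # ContinuousParameter
    'hidden_nodes': lambda: random.randint(20, 100),       # IntegerParameter
    'batch_size': lambda: random.choice([128, 256, 512]),  # CategoricalParameter
    'conv1_channels': lambda: random.choice([32, 64, 128]),
    'conv2_channels': lambda: random.choice([64, 128, 256, 512]),
}

def sample_config():
    """Draw one hyperparameter configuration, as the tuner does per job."""
    return {name: draw() for name, draw in hyperparameter_ranges.items()}

print(sample_config())
```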

To dig deeper into the patterns between hyperparameters and accuracy, I generated a parallel coordinates plot in W&B. The first five columns are the configuration parameters and the far-right column is the test accuracy. Each line corresponds to a single run. I've highlighted the runs with the best test accuracy to see where they land on the other columns. I discovered that a lower learning rate (config:lr) and fewer hidden nodes (config:hidden_nodes) correlated with higher test accuracy.
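A pattern spotted in a parallel coordinates plot can also be sanity-checked numerically by exporting the runs table and computing a correlation. The sketch below uses made-up example run data standing in for the real sweep results.

```python
# Hypothetical sketch: Pearson correlation between a hyperparameter and
# test accuracy across sweep runs. `runs` is made-up example data, not
# the actual sweep results.
from math import sqrt

runs = [
    {'lr': 0.0002, 'test_accuracy': 0.76},
    {'lr': 0.0004, 'test_accuracy': 0.71},
    {'lr': 0.0006, 'test_accuracy': 0.55},
    {'lr': 0.0009, 'test_accuracy': 0.31},
]

def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sqrt(sum((x - mx) ** 2 for x in xs))
    sy = sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

r = pearson([run['lr'] for run in runs],
            [run['test_accuracy'] for run in runs])
print(f"corr(lr, test_accuracy) = {r:.2f}")  # negative: lower lr, higher accuracy
```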

Parallel coordinates plot highlighting runs with the highest test accuracy

We can zoom in on individual accuracy metrics or even compare how our different models classify a single image. Here are the top 10 models overall, all struggling to correctly identify this picture of a dog.

To demonstrate the integration, we set up a sweep example in W&B over the CIFAR-10 dataset using PyTorch. If you want to reproduce this, I put my code on GitHub. For more info on the integration, check out our docs.