Monitor Your PyTorch Models With Five Extra Lines of Code

Lukas Biewald

I love PyTorch and I love tracking my experiments. It’s possible to use TensorBoard with PyTorch, but it can feel a little clunky. We recently added a feature to make it dead simple to monitor your PyTorch models with W&B!

I started with the PyTorch CIFAR-10 tutorial. It’s a fantastic tutorial, but it uses matplotlib to display images, which can be annoying on a remote server; it doesn’t plot the accuracy or loss curves; and it doesn’t let me inspect the gradients of the layers. Let’s fix all that with just a few lines of code!

At the top of my script I add the lines:

import wandb
wandb.init(config=args)  # args: the script's hyperparameters, e.g. parsed with argparse

This starts a W&B process that tracks the input hyperparameters and lets me save metrics and files. It also saves stdout and stderr and automatically tracks my GPU usage and other system metrics. For example, W&B charted GPU usage and temperature for one of my runs.
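The hyperparameters reach W&B through the `config` argument of `wandb.init`, which accepts a plain dict. A minimal sketch of collecting them with `argparse`, assuming flags whose defaults mirror the tutorial's SGD settings; the flag names and the helper `parse_hyperparams` are hypothetical, and the `wandb.init` call is left as a comment so the snippet runs without W&B installed:

```python
import argparse


def parse_hyperparams(argv=None):
    """Parse training hyperparameters (hypothetical flags) from the command line."""
    parser = argparse.ArgumentParser(description="CIFAR-10 training")
    # Defaults mirror the tutorial's settings; adjust to taste.
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch-size", type=int, default=4)  # stored as batch_size
    return parser.parse_args(argv)


args = parse_hyperparams([])  # empty list: use the defaults instead of sys.argv
config = vars(args)           # plain dict of hyperparameters
# In a real script (assumption: wandb is installed), pass it along:
#   import wandb
#   wandb.init(config=config)
print(config)
```

In a real training script you would call `parse_hyperparams()` with no argument so the values come from the command line, and every run's settings end up recorded next to its metrics.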

Now I can add a log command at the end of each epoch and easily see how my network is performing on each class:

class_acc = {}
for i in range(10):
    acc = 100 * class_correct[i] / class_total[i]
    print('Accuracy of %5s : %2d %%' % (classes[i], acc))
    class_acc["Accuracy of %s" % classes[i]] = acc
wandb.log(class_acc)
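The dictionary handed to `wandb.log` is simply metric names mapped to numbers. A minimal plain-Python sketch of building it from lists of predicted and true class indices; the helper `per_class_accuracy` is hypothetical, and it skips classes with no samples rather than dividing by zero:

```python
def per_class_accuracy(predictions, labels, classes):
    """Return {"Accuracy of <class>": percent} for every class that has samples."""
    correct = [0] * len(classes)
    total = [0] * len(classes)
    for pred, label in zip(predictions, labels):
        total[label] += 1
        if pred == label:
            correct[label] += 1
    return {
        "Accuracy of %s" % name: 100 * correct[i] / total[i]
        for i, name in enumerate(classes)
        if total[i] > 0  # avoid dividing by zero for classes never seen
    }


classes = ("cat", "dog")
print(per_class_accuracy([0, 1, 1, 1], [0, 0, 1, 1], classes))
```

The result can be passed straight to `wandb.log(...)` at the end of each epoch, so each class gets its own chart.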


At the end of each epoch I log a couple of example images to get a feel for what my network is doing. I can log these images just like metrics. In fact, I can log matplotlib graphs the same way, but that’s a topic for another blog post!

example_images = [wandb.Image(image, caption=classes[pred.item()])
                  for image, pred in zip(images, predicted)]

wandb.log({"Examples": example_images})
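The tutorial normalizes each image from [0, 1] to [-1, 1] and keeps it in channels-first (CHW) layout, so the raw tensors are not ordinary pictures. One option, sketched here with NumPy and assuming the tutorial's `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, is to undo that normalization and convert to an HWC `uint8` array before wrapping it in `wandb.Image`; the helper `to_displayable` is hypothetical:

```python
import numpy as np


def to_displayable(img_chw):
    """Convert a normalized CHW float image to an HWC uint8 array.

    Assumes Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), which maps [0, 1] to [-1, 1].
    """
    img = np.asarray(img_chw, dtype=np.float32) / 2 + 0.5  # undo normalization -> [0, 1]
    img = np.clip(img, 0.0, 1.0)                            # guard against values drifting out of range
    img = np.transpose(img, (1, 2, 0))                      # CHW -> HWC
    return np.round(img * 255).astype(np.uint8)


# Hypothetical usage inside the evaluation loop (tensors moved to NumPy first):
#   example_images = [wandb.Image(to_displayable(img.cpu().numpy()),
#                                 caption=classes[pred.item()])
#                     for img, pred in zip(images, predicted)]
```

Doing the conversion yourself mirrors the `img / 2 + 0.5` step in the tutorial's matplotlib `imshow` helper and avoids relying on any automatic rescaling.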

After I define my network, I use this magical command:

wandb.watch(net)

Now I get a histogram of each gradient in my network as it trains!
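`wandb.watch` hooks into the model for you. To make it concrete what a per-layer gradient histogram summarizes, here is a rough sketch in plain PyTorch that bins each parameter's `.grad` after one backward pass; the helper `gradient_histograms` and the tiny stand-in model are hypothetical, and this is not how W&B implements it:

```python
import torch
import torch.nn as nn


def gradient_histograms(model, bins=32):
    """Return {parameter name: histogram counts} for every parameter with a gradient."""
    hists = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue  # e.g. frozen parameters or ones unused in this step
        grad = param.grad.detach().flatten().float()
        # With min=max=0 (the defaults), torch.histc uses the data's own range.
        hists[name] = torch.histc(grad, bins=bins)
    return hists


# Tiny stand-in model and one backward pass to populate .grad.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()

for name, counts in gradient_histograms(model).items():
    print(name, int(counts.sum().item()), "values in", counts.numel(), "bins")
```

Watching how these distributions shift over training is what makes vanishing or exploding gradients easy to spot.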

Each new PyTorch run is added to my project’s runs table, where I can compare runs and look for deeper patterns.

Give it a try! You can run my sample CIFAR-10 code, and here are the results of my run on Weights & Biases.
