How to Integrate PyTorch Lightning with Weights & Biases

A quick tutorial on integrating Lightning with W&B, complete with executable code and interactive visualizations
Created on March 14 | Last edited on November 30

Introduction

In this tutorial, we’ll walk you through a few quick steps to integrate Weights & Biases with PyTorch Lightning. And while it takes just a couple of lines of code to get started...
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)
...the integration allows you to not only train, monitor, and reproduce your models but also:
  • log your configuration parameters
  • log your losses and metrics
  • keep track of your code
  • log your system metrics (GPU, CPU, memory, temperature, etc.)
  • visualize your data in customizable tables
  • share and collaborate across your team & organization
And that's just the tip of the iceberg. That said, this tutorial isn't meant to be exhaustive; it's meant to get you started using W&B with PyTorch Lightning. We'll use some simple MNIST experiments as our backbone here and, if you'd rather follow along in a Colab with executable code, we've got you covered there too.
Lastly, before we begin, if you’re new to either W&B or PyTorch Lightning, here’s a quick explainer of both:
Weights & Biases is the premier developer-first MLOps platform. With W&B, you can save everything you need to debug, compare and reproduce your models — architecture, hyperparameters, git commits, model weights, GPU usage, and even datasets and predictions.
PyTorch Lightning, meanwhile, is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training, 16-bit precision, or gradient accumulation.
You can read more about the integration in PyTorch Lightning’s docs or our own.
And with all that preamble out of the way, let’s get started in earnest, shall we?

Installing & Setting Up W&B

First, let's install PyTorch Lightning and W&B:
!pip install -q pytorch-lightning wandb
Then take a moment to double-check that our experiments will be associated with our own personal accounts:
import wandb
wandb.login()

Setting up Our Dataloader

For the purposes of our tutorial, we'll be using a vanilla PyTorch dataloader and the canonical MNIST dataset. This code brings in that dataset and lets us transform it for our experiments:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

dataset = MNIST(root="./MNIST", download=True, transform=transform)
training_set, validation_set = random_split(dataset, [55000, 5000])
training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)
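
To make the preprocessing above concrete, here is a minimal pure-Python sketch of the arithmetic behind the transform and the split. It assumes the standard 60,000-image MNIST training set; `normalize_pixel` is a hypothetical helper written for illustration, not a torchvision function.

```python
# Illustrative only: mirrors the arithmetic of ToTensor + Normalize on one pixel.
MNIST_MEAN = 0.1307  # mean passed to transforms.Normalize above
MNIST_STD = 0.3081   # std passed to transforms.Normalize above


def normalize_pixel(raw_value: int) -> float:
    """Scale a 0-255 pixel to [0, 1] (as ToTensor does), then standardize it."""
    scaled = raw_value / 255.0
    return (scaled - MNIST_MEAN) / MNIST_STD


# The split sizes passed to random_split must add up to the dataset length.
train_size, val_size = 55000, 5000
total_images = train_size + val_size  # assumes the 60,000-image MNIST training split

black = normalize_pixel(0)    # darkest pixel maps to a negative value
white = normalize_pixel(255)  # brightest pixel maps to a positive value
print(f"black={black:.4f}, white={white:.4f}, total={total_images}")
```

If the two split sizes did not sum to the dataset length, `random_split` would raise an error, so it is worth checking them whenever you change the split.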

Defining Our Model

Now that we’ve installed our dependencies and brought in our data, we’re ready to define our model. The code below does just that; saving model checkpoints, tracking experiments, and more are covered in the sections that follow. First though:
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule

class MNIST_LitModule(LightningModule):

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        '''method used to define our model parameters'''
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

        # save hyper-parameters to self.hparams (auto-logged by W&B)
        self.save_hyperparameters()

    def forward(self, x):
        '''method used for inference input -> output'''

        batch_size, channels, height, width = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # two hidden layers (linear + relu), then a linear output layer (logits)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        return x

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)

        return loss

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        # Let's return preds to use it in a custom callback
        return preds

    def test_step(self, batch, batch_idx):
        '''used for logging metrics'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def configure_optimizers(self):
        '''defines model optimizer'''
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        # torchmetrics >= 0.11 requires the `task` argument; on older versions,
        # use accuracy(preds, y) instead
        acc = accuracy(preds, y, task='multiclass',
                       num_classes=self.hparams.n_classes)
        return preds, loss, acc

model = MNIST_LitModule(n_layer_1=128, n_layer_2=128)
And with that, our model is ready! A quick note on the code above. You can:
  • Call self.save_hyperparameters() in __init__ to automatically log your hyperparameters to W&B
  • Call self.log in training_step and validation_step to log the metrics
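
For intuition about what `_get_preds_loss_accuracy` computes, the sketch below redoes the same three steps (argmax, cross-entropy, accuracy) in plain Python on a made-up batch of three-class logits. The helper names and the numbers are hypothetical; the real model relies on `torch.argmax`, `CrossEntropyLoss`, and `torchmetrics`.

```python
import math


def argmax(values):
    """Index of the largest value, like torch.argmax over the class dimension."""
    return max(range(len(values)), key=lambda i: values[i])


def cross_entropy(logits, target):
    """Negative log-softmax of the target class for one sample."""
    log_sum_exp = math.log(sum(math.exp(v) for v in logits))
    return log_sum_exp - logits[target]


# Hypothetical batch: 3 samples, 3 classes.
batch_logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 1.5], [1.0, 1.2, 0.1]]
labels = [0, 2, 0]

preds = [argmax(row) for row in batch_logits]
# Mean over the batch, matching CrossEntropyLoss's default reduction.
loss = sum(cross_entropy(row, y) for row, y in zip(batch_logits, labels)) / len(labels)
acc = sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)
print(preds, round(loss, 4), round(acc, 4))
```

Here the third sample is misclassified, so two of three predictions are correct. These are the per-batch values that `self.log` sends to W&B.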

Saving Model Checkpoint

Saving model checkpoints is as easy as you’d expect (i.e. just a couple of lines of code). Note: to log model checkpoints to W&B, you need both the ModelCheckpoint callback and the log_model argument of WandbLogger.
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')
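
As a rough illustration of what `monitor='val_accuracy', mode='max'` asks for, the sketch below picks the epoch whose monitored value is best. It is not Lightning's implementation; `best_checkpoint` and the accuracy values are made up for the example.

```python
def best_checkpoint(history, mode="max"):
    """Return the (epoch, value) pair a 'keep the best' policy would retain.

    `history` is a list of monitored values, one per epoch.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    pick = max if mode == "max" else min
    best_epoch = pick(range(len(history)), key=lambda epoch: history[epoch])
    return best_epoch, history[best_epoch]


# Made-up validation accuracies for five epochs.
val_accuracy = [0.91, 0.95, 0.97, 0.96, 0.965]
epoch, value = best_checkpoint(val_accuracy, mode="max")
print(f"keep checkpoint from epoch {epoch} (val_accuracy={value})")
```

For a metric where lower is better, such as `val_loss`, you would pass `mode='min'` instead.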

Tracking Experiments with WandbLogger

PyTorch Lightning has a WandbLogger that lets you easily log your experiments with Weights & Biases. Just pass it to your Trainer to log to W&B. You can check out the WandbLogger docs for all parameters.
Note: to log the metrics to a specific W&B Team, pass your team name to the entity argument in WandbLogger.
Here are a few arguments that you might need (see the docs above for more):
  • Log models: WandbLogger(... ,log_model='all') or WandbLogger(... ,log_model=True)
  • Set custom run names: WandbLogger(... ,name='my_run_name')
  • Organize runs: WandbLogger(... ,project='my_project')
  • Log histograms of gradients & parameters: WandbLogger.watch(model)
  • Log hyperparameters: Call self.save_hyperparameters() within LightningModule.__init__()
  • Log custom objects (like images, videos, or molecules): Use WandbLogger.log_text, WandbLogger.log_image and WandbLogger.log_table
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger(project='MNIST',  # group runs in "MNIST" project
                           log_model='all')  # log all new checkpoints during training

Logging Images, Text, and More with WandbLogger

PyTorch Lightning is extensible through its callback system. That means we can create a custom callback to automatically log sample predictions during validation.
WandbLogger provides convenient media logging functions:
  • WandbLogger.log_text for text data
  • WandbLogger.log_image for images
  • WandbLogger.log_table for W&B Tables.
An alternative to self.log in the model class is to call wandb.log({dict}) or trainer.logger.experiment.log({dict}) directly.
Below, we’ll log the first 20 images in the first batch of the validation dataset along with the predicted and ground truth labels.
from pytorch_lightning.callbacks import Callback

class LogPredictionsCallback(Callback):

    # `dataloader_idx=0` keeps the hook compatible with Lightning versions
    # that omit this argument when there is a single validation dataloader
    def on_validation_batch_end(
            self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Called when the validation batch ends."""

        # `outputs` comes from `LightningModule.validation_step`
        # which corresponds to our model predictions in this case

        # Let's log 20 sample image predictions from first batch
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [f'Ground Truth: {y_i} - Prediction: {y_pred}' for y_i, y_pred in zip(y[:n], outputs[:n])]

            # Option 1: log images with `WandbLogger.log_image`
            wandb_logger.log_image(key='sample_images', images=images, caption=captions)

            # Option 2: log predictions as a Table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in zip(x[:n], y[:n], outputs[:n])]
            wandb_logger.log_table(key='sample_table', columns=columns, data=data)

log_predictions_callback = LogPredictionsCallback()
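
To see the shape of the data the callback hands to `log_image` and `log_table`, here is a small plain-Python sketch with placeholder values. The strings standing in for image tensors, the labels, and the two helper functions are all hypothetical; only the caption format and the column order come from the callback above.

```python
def build_captions(truths, preds):
    """One caption string per image, in the same format the callback uses."""
    return [f"Ground Truth: {t} - Prediction: {p}" for t, p in zip(truths, preds)]


def build_table_rows(images, truths, preds):
    """One row per sample, ordered to match ['image', 'ground truth', 'prediction']."""
    return [[img, t, p] for img, t, p in zip(images, truths, preds)]


# Placeholder batch: strings stand in for image tensors.
images = ["img_0", "img_1", "img_2"]
truths = [7, 2, 1]
preds = [7, 2, 4]

n = 2  # keep only the first n samples, as the callback does with n = 20
captions = build_captions(truths[:n], preds[:n])
rows = build_table_rows(images[:n], truths[:n], preds[:n])
print(captions)
print(rows)
```

The list of captions must line up one-to-one with the list of images, and each table row must have as many entries as there are columns.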

Let’s Train our Model

Ok, now that we’re set up, let’s actually get training here:
trainer = Trainer(
    logger=wandb_logger,                  # W&B integration
    callbacks=[log_predictions_callback,  # logging of sample predictions
               checkpoint_callback],      # our model checkpoint callback
    max_epochs=5)                         # number of epochs

trainer.fit(model, training_loader, validation_loader)
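
As a back-of-the-envelope check on how long this run is, the sketch below works out the number of batches from the split sizes, batch size, and epoch count used above. It assumes the DataLoader keeps the final partial batch, which is its default behavior.

```python
import math

train_size, val_size = 55000, 5000  # split sizes from random_split
batch_size = 64                     # batch size of both dataloaders
max_epochs = 5                      # max_epochs passed to the Trainer

# The last, smaller batch is kept, so round up.
train_batches_per_epoch = math.ceil(train_size / batch_size)
val_batches_per_epoch = math.ceil(val_size / batch_size)
total_training_steps = train_batches_per_epoch * max_epochs

print(train_batches_per_epoch, val_batches_per_epoch, total_training_steps)
```

The total step count is roughly how far the x-axis of your step-based training charts in W&B will extend for this run.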
When we want to close our W&B run, we call wandb.finish() (note: this is mainly useful in notebooks as it’s called automatically in scripts).
wandb.finish()
Now we can monitor our loss, metrics, gradients, parameters and even sample predictions as our model trains! For the above, you might see something like this:

[Interactive W&B panel: run set of 30 runs]


Conclusion

That just about wraps up our quick introduction to integrating PyTorch Lightning with W&B. As a reminder, you can always get more detail about the integration in our docs or check out the associated colab for a version of this report with executable code.
Thanks for reading!