How to Integrate PyTorch Lightning with Weights & Biases
A quick tutorial on integrating Lightning with W&B, complete with executable code and interactive visualizations
Created on March 14|Last edited on November 30
Introduction
In this tutorial, we’re going to run you through a few quick steps to integrate Weights & Biases with PyTorch Lightning. And while it takes just a couple lines of code to get started...
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)
...the integration allows you to not only train, monitor, and reproduce your models but also:
- log your configuration parameters
- log your losses and metrics
- keep track of your code
- log your system metrics (GPU, CPU, memory, temperature, etc)
- visualize your data in customizable tables
- share and collaborate across your team & organization
And that's just the tip of the iceberg. This tutorial is not meant to be exhaustive; its goal is to get you started using W&B with PyTorch Lightning. We'll use some simple MNIST experiments as our backbone and, if you'd rather follow along in a Colab with executable code, we've got you covered there too.
Lastly, before we begin, if you’re new to either W&B or PyTorch Lightning, here’s a quick explainer of both:
Weights & Biases is the premier developer-first MLOps platform. With W&B, you can save everything you need to debug, compare and reproduce your models — architecture, hyperparameters, git commits, model weights, GPU usage, and even datasets and predictions.
PyTorch Lightning, meanwhile, is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training, 16-bit precision, or gradient accumulation.
And with all that preamble out of the way, let’s get started in earnest, shall we?
Installing & Setting Up W&B
First, let's install PyTorch Lightning and W&B:
!pip install -q pytorch-lightning wandb
Then take a moment to double-check that our experiments will be associated with our own personal accounts:
import wandb

wandb.login()
Setting up Our Dataloader
For the purposes of our tutorial, we'll be using a vanilla PyTorch dataloader and the canonical MNIST dataset. This code brings in that dataset and lets us transform it for our experiments:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
dataset = MNIST(root="./MNIST", download=True, transform=transform)
training_set, validation_set = random_split(dataset, [55000, 5000])

training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)
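To make the preprocessing concrete, here is a small worked sketch of what the numbers above imply. It assumes that `ToTensor` scales pixel values to the range [0, 1] and that `Normalize` applies `(x - mean) / std`; the helper name `normalize` is ours for illustration and is not part of torchvision.

```python
MEAN, STD = 0.1307, 0.3081   # values passed to transforms.Normalize above
MNIST_TRAIN_IMAGES = 60000   # size of the MNIST training split (train=True is the default)
SPLIT = [55000, 5000]        # lengths passed to random_split above


def normalize(pixel: float) -> float:
    """Single-value mirror of Normalize: (x - mean) / std."""
    return (pixel - MEAN) / STD


darkest = normalize(0.0)     # a fully black pixel after ToTensor
brightest = normalize(1.0)   # a fully white pixel after ToTensor

# random_split expects the lengths to add up to the dataset size
split_is_valid = sum(SPLIT) == MNIST_TRAIN_IMAGES

print(f"normalized pixel range: [{darkest:.4f}, {brightest:.4f}]")
print(f"split covers the whole dataset: {split_is_valid}")
```

If you change the split sizes, keep their sum equal to the dataset length, otherwise `random_split` will raise an error.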
Defining Our Model
Ok, so now that we’ve installed our dependencies and brought in our data, we’re ready to train our model and visualize our work on W&B. The code below defines our model; the following sections show how to save model checkpoints, track experiments, and more. First, though:
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule


class MNIST_LitModule(LightningModule):

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        '''method used to define our model parameters'''
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

        # save hyper-parameters to self.hparams (auto-logged by W&B)
        self.save_hyperparameters()

    def forward(self, x):
        '''method used for inference input -> output'''
        batch_size, channels, height, width = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # two (linear + relu) blocks followed by a final linear layer
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        # Let's return preds to use it in a custom callback
        return preds

    def test_step(self, batch, batch_idx):
        '''used for logging metrics'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def configure_optimizers(self):
        '''defines model optimizer'''
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        # torchmetrics 0.11 and later require the `task` argument;
        # n_classes is read back from the saved hyperparameters
        acc = accuracy(preds, y, task='multiclass',
                       num_classes=self.hparams.n_classes)
        return preds, loss, acc
model = MNIST_LitModule(n_layer_1=128, n_layer_2=128)
And with that, our model is ready! A quick note on the code above. You can:
- Call self.save_hyperparameters() in __init__ to automatically log your hyperparameters to W&B
- Call self.log in training_step and validation_step to log the metrics
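As a quick sanity check on the architecture defined above, the number of trainable parameters can be worked out by hand. The sketch below assumes each `Linear` layer holds a weight matrix plus a bias vector (the default), and uses the sizes passed when the model is instantiated (`n_layer_1=128`, `n_layer_2=128`); `linear_params` is a helper introduced here for illustration.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


# Input size, two hidden sizes (as instantiated above), and number of classes
layer_sizes = [28 * 28, 128, 128, 10]

# Pair each layer's input size with its output size
per_layer = [linear_params(n_in, n_out)
             for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)

print(per_layer)  # [100480, 16512, 1290]
print(total)      # 118282
```

The total should match the parameter count Lightning prints in its model summary when training starts, which makes it a handy check that the layer sizes you logged are the ones actually in use.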
Saving Model Checkpoint
Saving model checkpoints is as easy as you’d expect (i.e. just a couple lines of code). Note: to log model checkpoints to W&B, you need the ModelCheckpoint callback together with the log_model argument of WandbLogger.
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')
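The `monitor` and `mode` arguments decide which checkpoint counts as the best one. The sketch below is a simplified illustration of that idea in plain Python, not Lightning's actual implementation; the function `best_epoch` and the accuracy values are made up for the example.

```python
def best_epoch(metric_per_epoch, mode="max"):
    """Return (epoch_index, value) of the best monitored value.

    Simplified illustration of monitor/mode, not Lightning's implementation.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    pick = max if mode == "max" else min
    # Choose the epoch index whose metric value is best under the given mode
    best_idx = pick(range(len(metric_per_epoch)),
                    key=metric_per_epoch.__getitem__)
    return best_idx, metric_per_epoch[best_idx]


# Hypothetical val_accuracy values for five epochs (made-up numbers)
val_accuracy = [0.91, 0.95, 0.94, 0.97, 0.96]

epoch, value = best_epoch(val_accuracy, mode="max")
print(epoch, value)  # 3 0.97
```

For a metric such as `val_loss`, where lower is better, you would monitor it with `mode='min'` instead.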
Tracking Experiments with WandbLogger
PyTorch Lightning has a WandbLogger that lets you easily log your experiments with Weights & Biases. Just pass it to your Trainer to log to W&B. You can check out the WandbLogger docs for all parameters.
Note: to log the metrics to a specific W&B Team, pass your team name to the entity argument in WandbLogger
Here are a few arguments that you might need (though you can click the docs above for more)
- Log models: WandbLogger(..., log_model='all') or WandbLogger(..., log_model=True)
- Set custom run names: WandbLogger(..., name='my_run_name')
- Organize runs: WandbLogger(..., project='my_project')
- Log histograms of gradients & parameters: WandbLogger.watch(model)
- Log hyperparameters: Call self.save_hyperparameters() within LightningModule.__init__()
- Log custom objects (like images, videos, or molecules): Use WandbLogger.log_text, WandbLogger.log_image and WandbLogger.log_table
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger(project='MNIST',  # group runs in "MNIST" project
                           log_model='all')  # log all new checkpoints during training
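When several of the arguments listed above are used together, it can help to collect them in one place. The following is a minimal sketch of that pattern: the helper `wandb_logger_kwargs` and the values `mlp-128-128` and `my-team` are placeholders invented for this example, while `project`, `name`, `entity`, and `log_model` are the WandbLogger arguments mentioned in this section.

```python
def wandb_logger_kwargs(project, run_name=None, team=None, log_model="all"):
    """Collect WandbLogger arguments in one dict, skipping unset ones."""
    kwargs = {"project": project, "log_model": log_model}
    if run_name is not None:
        kwargs["name"] = run_name   # custom run name
    if team is not None:
        kwargs["entity"] = team     # W&B team that should own the run
    return kwargs


# Placeholder values; replace them with your own run and team names
kwargs = wandb_logger_kwargs("MNIST", run_name="mlp-128-128", team="my-team")
print(kwargs)

# With pytorch-lightning and wandb installed, the dict can be unpacked directly:
# wandb_logger = WandbLogger(**kwargs)
```

Keeping the settings in a plain dict also makes it easy to reuse the same configuration across scripts.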
Logging Images, Text, and More with WandbLogger
PyTorch Lightning is extensible through its callback system. That means we can create a custom callback to automatically log sample predictions during validation.
WandbLogger provides convenient media logging functions:
- WandbLogger.log_text for text data
- WandbLogger.log_image for images
An alternative to self.log in the model class is directly using wandb.log({dict}) or trainer.logger.experiment.log({dict})
Below, we’ll log the first 20 images in the first batch of the validation dataset along with the predicted and ground truth labels.
from pytorch_lightning.callbacks import Callback


class LogPredictionsCallback(Callback):

    # `dataloader_idx` defaults to 0 so the hook also works when
    # Lightning does not pass it (single validation dataloader)
    def on_validation_batch_end(
            self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Called when the validation batch ends."""

        # `outputs` comes from `LightningModule.validation_step`
        # which corresponds to our model predictions in this case

        # Let's log 20 sample image predictions from first batch
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [f'Ground Truth: {y_i} - Prediction: {y_pred}'
                        for y_i, y_pred in zip(y[:n], outputs[:n])]

            # Option 1: log images with `WandbLogger.log_image`
            wandb_logger.log_image(key='sample_images', images=images, caption=captions)

            # Option 2: log predictions as a Table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), y_i, y_pred]
                    for x_i, y_i, y_pred in zip(x[:n], y[:n], outputs[:n])]
            wandb_logger.log_table(key='sample_table', columns=columns, data=data)


log_predictions_callback = LogPredictionsCallback()
Let’s Train Our Model
Ok, now that we’re set up, let’s actually get training here:
trainer = Trainer(
    logger=wandb_logger,                    # W&B integration
    callbacks=[log_predictions_callback,    # logging of sample predictions
               checkpoint_callback],        # our model checkpoint callback
    max_epochs=5)                           # number of epochs
trainer.fit(model, training_loader, validation_loader)
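To get a feel for how much data this run sends to W&B, here is a rough back-of-the-envelope sketch. It assumes the Trainer's default `log_every_n_steps=50` (we did not override it), a single device, and no gradient accumulation; the variable names are ours.

```python
import math

TRAIN_SIZE, BATCH_SIZE, MAX_EPOCHS = 55000, 64, 5
LOG_EVERY_N_STEPS = 50  # assumed: Trainer default when not overridden

# The last, smaller batch still counts as a step (drop_last is False by default)
steps_per_epoch = math.ceil(TRAIN_SIZE / BATCH_SIZE)
total_steps = steps_per_epoch * MAX_EPOCHS

# Step-level metrics such as train_loss are logged once every N steps
logged_points = total_steps // LOG_EVERY_N_STEPS

print(steps_per_epoch, total_steps, logged_points)  # 860 4300 86
```

If the training curves look too coarse, passing a smaller `log_every_n_steps` to the Trainer gives you more points, at the cost of more logging overhead.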
When we want to close our W&B run, we call wandb.finish() (note: this is mainly useful in notebooks as it’s called automatically in scripts).
wandb.finish()
Now we can monitor our loss, metrics, gradients, parameters and even sample predictions as our model trains! For the above, you might see something like this:
[Interactive W&B panel: Run set (30)]
Conclusion
That just about wraps up our quick introduction to integrating PyTorch Lightning with W&B. As a reminder, you can always get more detail about the integration in our docs or check out the associated colab for a version of this report with executable code.
Thanks for reading!