Multi-GPU Training Using PyTorch Lightning

Run multi-GPU training with PyTorch Lightning and visualize GPU performance. Made by Ayush Thakur using Weights & Biases
Ayush Thakur

A GPU is the workhorse of most deep learning workflows. If you have used TensorFlow Keras, you probably know that the same training script can train a model on multiple GPUs, or even on a TPU, with minimal or no changes.

In this report, we will see how to make our PyTorch script accelerator agnostic. In other words, the same PyTorch code, organized with PyTorch Lightning, can be trained on multiple GPUs without being rewritten.


This report is part of my PyTorch Lightning series. Before you train your model on multiple GPUs, make sure to check out the two earlier reports in the series, which will get you started with PyTorch Lightning.

PyTorch Lightning lets you decouple research from engineering. Getting your PyTorch code to train on multiple GPUs can be daunting if you are inexperienced, and a waste of time if you just want to scale your research. PyTorch Lightning is more of a "style guide" that helps you organize your PyTorch code so that you do not have to write boilerplate code yourself, including the boilerplate for multi-GPU training.

Common workflow with PyTorch Lightning

In this report, we will see how easy it is to train our model on multiple GPUs.

Multi-GPU Training

I structured the PyTorch code for image classification on the Caltech-101 dataset using PyTorch Lightning. I then trained the classifier on two K80 GPUs through my Google Cloud Platform (GCP) account, and the only change needed was one minor change to the Trainer.

How to train on multiple GPUs?

Keep everything else the same and pass the gpus and accelerator arguments to the PyTorch Lightning Trainer. I had access to two K80 GPUs, so I used gpus=2. Because I was training in a Jupyter Notebook, I used accelerator='dp', where dp stands for Data Parallel. DP runs in a single process, so it works inside a notebook, whereas Lightning's ddp mode launches the training script in separate processes. We will go into the specifics soon, but first let's visualize the system metrics using Weights & Biases.

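```python
import pytorch_lightning as pl

# Initialize a trainer on 2 GPUs with Data Parallel (Lightning 1.x-era arguments)
trainer = pl.Trainer(max_epochs=50, gpus=2, accelerator='dp')
```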
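Here is some intuition for what Data Parallel (the accelerator='dp' mode) does. Each GPU receives a slice of every batch and computes gradients on that slice, and those gradients are then combined into a single update. The sketch below simulates two replicas on the CPU. It checks that averaging the per-chunk gradients matches the full-batch gradient. It rests on several assumptions: a toy `nn.Linear` model, random data, a batch size divisible by the number of replicas, and a hypothetical helper called `full_and_split_grads`. It does not reproduce how `torch.nn.DataParallel` actually copies the model and moves tensors between devices.

```python
import torch
import torch.nn.functional as F
from torch import nn


def full_and_split_grads(model, x, y, n_replicas=2):
    """Hypothetical helper: full-batch gradients vs. the average of per-chunk gradients.

    Assumes the batch size is divisible by n_replicas, so every chunk has the
    same size and averaging per-chunk mean losses matches the full-batch mean.
    """
    # Gradient a single device would compute on the whole batch.
    model.zero_grad()
    F.cross_entropy(model(x), y).backward()
    full = [p.grad.clone() for p in model.parameters()]

    # Simulated replicas: each one sees only its chunk of the batch.
    summed = [torch.zeros_like(p) for p in model.parameters()]
    for x_chunk, y_chunk in zip(x.chunk(n_replicas), y.chunk(n_replicas)):
        model.zero_grad()
        F.cross_entropy(model(x_chunk), y_chunk).backward()
        for acc, p in zip(summed, model.parameters()):
            acc += p.grad  # in-place accumulation into the list entry
    averaged = [g / n_replicas for g in summed]
    return full, averaged


torch.manual_seed(0)
model = nn.Linear(10, 3)
x, y = torch.randn(8, 10), torch.randint(0, 3, (8,))
full, averaged = full_and_split_grads(model, x, y, n_replicas=2)
print(all(torch.allclose(f, a, atol=1e-6) for f, a in zip(full, averaged)))  # True
```

Each replica's loss is a mean over an equally sized chunk, so the average of their gradients equals the full-batch gradient. That is why splitting a batch across two GPUs should not change the update made at each step. The exceptions are floating-point differences and layers such as batch norm, which compute statistics per chunk.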

Weights & Biases automatically captures GPU-related system metrics such as GPU utilization, memory usage, temperature and power draw. The panels shown below display some of the most important ones, along with the training and test metrics.