
How To Write Efficient Training Loops in PyTorch

In this tutorial, we cover how to write extremely memory- and compute-efficient training loops in PyTorch, complete with shared code and interactive visualizations.

Over the past few weeks, we've looked at a number of ways to improve the "efficiency" of PyTorch training loops.

In this article, we'll stitch all of those ideas together into a general PyTorch training loop template that is both memory- and compute-efficient!

Show Me the Code

Here's what a general training loop in PyTorch looks like (I know we've looked at this in the past, but let's take one more look for posterity's sake):
optimizer = ...

for epoch in range(...):
    for i, sample in enumerate(dataloader):
        inputs, labels = sample
        optimizer.zero_grad()

        # Forward Pass
        outputs = model(inputs)

        # Compute Loss and Perform Back-propagation
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Update Optimizer
        optimizer.step()
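
To make the placeholders concrete, here's a minimal, self-contained sketch of the same loop using a hypothetical toy model and random data (the model, dataset, and hyperparameters below are illustrative assumptions, not part of the original template):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy setup, purely for illustration
model = nn.Linear(16, 4)                      # stand-in for your real model
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,)))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for epoch in range(3):
    for i, sample in enumerate(dataloader):
        inputs, labels = sample
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Compute the loss and backpropagate
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Update the parameters
        optimizer.step()
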
Well, let's see how we can improve this template:
  • Automatic Mixed Precision (AMP): automatically casts your tensors to lower-precision types with a smaller memory footprint.
  • Accumulation Steps: instead of stepping the optimizer after every batch, we accumulate gradients over several forward/backward passes and only then update the weights (see the isolated sketch after this list).
  • Gradient Scaling: scaling the loss (and therefore the gradients) by some factor so small gradient values aren't flushed to zero in lower precision.
  • Garbage Collection: freeing up unused memory and clearing the CUDA cache.
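
For gradient accumulation in particular, the detail that's easy to get wrong is the loss normalization: divide the loss by the number of accumulation steps before calling backward(), so the accumulated gradients match what a single large batch would produce. Here's a minimal sketch of that idea in isolation, reusing the toy model, dataloader, loss_fn, and optimizer from the sketch above (NUM_ACCUMULATION_STEPS = 4 is an assumed value):

NUM_ACCUMULATION_STEPS = 4  # assumed value, purely for illustration

for idx, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)

    # Divide before backward() so the summed gradients average over the accumulated batches
    loss = loss_fn(outputs, labels) / NUM_ACCUMULATION_STEPS
    loss.backward()

    # Step and reset the gradients only every NUM_ACCUMULATION_STEPS batches (or on the last batch)
    if (idx + 1) % NUM_ACCUMULATION_STEPS == 0 or (idx + 1) == len(dataloader):
        optimizer.step()
        optimizer.zero_grad()
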
Let's stitch all of these tips together into a better training loop.
import gc
import torch

optimizer = ...
scaler = torch.cuda.amp.GradScaler()
NUM_ACCUMULATION_STEPS = ...

for epoch in range(...):
    for idx, sample in enumerate(dataloader):
        inputs, labels = sample

        # ⭐️⭐️ Automatic Tensor Casting
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            # Normalize the loss to account for gradient accumulation
            loss = criterion(outputs, labels) / NUM_ACCUMULATION_STEPS

        # ⭐️⭐️ Automatic Gradient Scaling
        scaler.scale(loss).backward()

        # ⭐️⭐️ Gradient Accumulation
        if ((idx + 1) % NUM_ACCUMULATION_STEPS == 0) or (idx + 1 == len(dataloader)):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    # ⭐️⭐️ Garbage Collection
    torch.cuda.empty_cache()
    _ = gc.collect()
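
As a usage example, here's the same template applied end to end to a hypothetical toy model. Everything concrete below (the linear model, random dataset, SGD optimizer, learning rate, and NUM_ACCUMULATION_STEPS = 4) is an illustrative assumption, and a CUDA device is required for torch.cuda.amp to have any effect:

import gc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda")                 # AMP as used here targets CUDA
model = nn.Linear(16, 4).to(device)           # hypothetical toy model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()
NUM_ACCUMULATION_STEPS = 4                    # assumed value for illustration

dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,)))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for epoch in range(3):
    for idx, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Mixed-precision forward pass with the loss normalized for accumulation
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels) / NUM_ACCUMULATION_STEPS

        # Scaled backward pass; gradients accumulate across iterations
        scaler.scale(loss).backward()

        # Step and reset only at accumulation boundaries (or on the last batch)
        if (idx + 1) % NUM_ACCUMULATION_STEPS == 0 or (idx + 1) == len(dataloader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    # Free cached GPU memory and collect garbage once per epoch
    torch.cuda.empty_cache()
    gc.collect()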

Summary

In this article, you saw how you can create memory- and compute-efficient training loops in PyTorch using tips from previous reports.
To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and from-scratch code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental development topics like GPU Utilization and Saving Models.
