
Distributed Training with Shared Mode

End-to-end example of training a model on a multi-node, multi-GPU Kubernetes cluster in GKE using W&B's shared mode, which lets multiple independent processes log consistently to the same run ID.
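The core idea is that every worker process calls wandb.init() with mode="shared" and the same run ID, while only one process acts as the primary writer. A minimal sketch, distilled from the full training script below (the rank and run ID come from environment variables set by the launcher and the Kubernetes manifest):

import os

import wandb

# RANK is set by torchrun; WANDB_RUN_ID is the shared run ID all processes reuse.
rank = int(os.environ.get("RANK", "0"))

settings = wandb.Settings(
    mode="shared",           # several processes may write to the same run
    x_label=f"rank-{rank}",  # label this process in system metrics and console logs
)
if rank != 0:
    settings.x_primary = False              # only rank 0 uploads non-console files
    settings.x_update_finish_state = False  # only rank 0 marks the run as finished

run = wandb.init(
    project="simple-cnn-ddp",
    id=os.environ["WANDB_RUN_ID"],
    settings=settings,
)
run.log({"heartbeat": rank})
run.finish()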

Training setup in GKE

Create a 2-node cluster with two NVIDIA T4 GPUs per node (four GPUs in total):
gcloud container clusters create my-gpu-cluster \
--zone=us-central1-b \
--num-nodes=2 \
--machine-type=n1-standard-4 \
--accelerator=type=nvidia-tesla-t4,count=2 \
--disk-size=200GB \
--disk-type=pd-ssd
Install the NVIDIA device plugin so that Kubernetes can schedule GPU workloads (make sure the NVIDIA drivers are also installed on the GPU nodes):
kubectl apply -f https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.14.0/nvidia-device-plugin.yml

Build and push the image to GCR.

It might be wise to build on an Ubuntu VM in GCP. Provision a 100 GB SSD and install Docker:
sudo apt update && sudo apt upgrade -y
sudo apt-get install -y docker.io
Install the remaining tools, add your user to the docker group, and authorize Docker with GCR (the gcloud CLI is typically preinstalled on GCP VM images):
sudo apt install -y python3 python3-pip
sudo usermod -a -G docker ${USER}
# echo your JSON key > creds.json
gcloud auth activate-service-account --key-file=creds.json
cat creds.json | docker login -u _json_key --password-stdin https://gcr.io
Dockerfile:
# Use a PyTorch 2 image with CUDA 12 and cuDNN 9 as a base.
FROM pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel

# Install additional system dependencies.
# RUN apt-get update && apt-get install -y --no-install-recommends \
# libopencv-dev && \
# rm -rf /var/lib/apt/lists/*

# Install wandb.
RUN pip install --upgrade pip && pip install wandb

# Copy the training script into the container.
RUN mkdir /ai
COPY train.py /ai/train.py

# Set the entrypoint.
ENTRYPOINT ["python", "/ai/train.py"]

# CMD ["--epochs", "10", "--batch_size", "128", "--lr", "0.001"]
Once Docker is set up and you are authorized to push to gcr.io, build the image and push it:
export PROJECT=<PROJECT>
# Build the Docker image locally.
docker build --platform linux/amd64 -t gcr.io/$PROJECT/ddp-training:latest .
# Push the image to GCR.
docker push gcr.io/$PROJECT/ddp-training:latest
Create a Kubernetes secret with your W&B API key:
kubectl create secret generic wandb-secret --from-literal=WANDB_API_KEY=<YOUR_API_KEY>
Apply the Kubernetes manifest and check that the pods come up:
kubectl apply -f training.yaml
kubectl get pods

Kubernetes Service and StatefulSet spec

apiVersion: v1
kind: Service
metadata:
  name: vit-ddp
  labels:
    app: vit-ddp
spec:
  clusterIP: None  # headless service so that the pods have stable network IDs
  selector:
    app: vit-ddp
  ports:
    - port: 29500
      protocol: TCP
      name: rendezvous
---
apiVersion: apps/v1
kind: StatefulSet
metadata:
  name: vit-ddp
  labels:
    app: vit-ddp
spec:
  serviceName: vit-ddp
  replicas: 2
  selector:
    matchLabels:
      app: vit-ddp
  template:
    metadata:
      labels:
        app: vit-ddp
    spec:
      # restartPolicy: OnFailure
      # restartPolicy: Never
      containers:
        - name: vit-trainer
          image: gcr.io/<PROJECT>/ddp-training:latest
          # Request 2 GPUs per pod.
          resources:
            limits:
              cpu: 4
              nvidia.com/gpu: 2
              ephemeral-storage: "100Gi"
            requests:
              cpu: 2
              ephemeral-storage: "50Gi"
          env:
            # Pass wandb credentials and a common run id.
            - name: WANDB_API_KEY
              valueFrom:
                secretKeyRef:
                  name: wandb-secret
                  key: WANDB_API_KEY
            - name: WANDB_RUN_ID
              value: "liquid-tranquility-51"
            - name: NCCL_DEBUG
              value: "INFO"
            - name: TORCH_DISTRIBUTED_DEBUG
              value: "DETAIL"
          # Use a bash command to extract the ordinal from the pod hostname.
          command: ["/bin/bash", "-c"]
          args:
            - |
              # Extract the node_rank from the pod hostname.
              # Expecting hostname of the form: vit-ddp-0, vit-ddp-1, etc.
              NODE_RANK=${HOSTNAME##*-}
              echo "Starting torchrun with node_rank=${NODE_RANK}"
              # For non-zero ranks, add a delay to allow rank 0 to initialize.
              if [ "$NODE_RANK" -gt 0 ]; then
                echo "Non-zero rank detected. Sleeping for 10 seconds to allow the master to be ready."
                sleep 10
              fi
              torchrun \
                --nnodes=2 \
                --nproc_per_node=2 \
                --node_rank=${NODE_RANK} \
                --rdzv_backend=c10d \
                --rdzv_endpoint=vit-ddp:29500 \
                /ai/train.py --epochs 120 --batch_size 256 --lr 0.0007
          volumeMounts:
            # Mount a tmpfs volume to /dev/shm for increased shared memory
            - name: dshm
              mountPath: /dev/shm
      volumes:
        - name: dshm
          emptyDir:
            medium: Memory
            sizeLimit: 20Gi  # Adjust this size as needed
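
Inside each pod, torchrun handles the rendezvous with the other node and then spawns one worker process per GPU, exporting the distributed context as environment variables; the training script only needs to read them. A minimal sketch of what each of the four workers sees with the settings above:

import os

# Set by torchrun for every worker process it launches.
local_rank = int(os.environ["LOCAL_RANK"])   # GPU index on this node: 0 or 1
global_rank = int(os.environ["RANK"])        # unique rank across the cluster: 0..3
world_size = int(os.environ["WORLD_SIZE"])   # nnodes * nproc_per_node = 2 * 2 = 4

print(f"worker {global_rank}/{world_size} on local GPU {local_rank}")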


Training script

#!/usr/bin/env python
"""Distributed Training Script using PyTorch DDP with a SimpleCNN on FashionMNIST.

Each process prints the GPU index in use.
Logs misclassified examples every 20 epochs as wandb.Images.
Also logs the code from the current directory.
"""

import os
import argparse
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import torchvision
import torchvision.transforms as transforms

import wandb


def setup_distributed():
    """Initialize the distributed environment and set the GPU for this process."""
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # Print GPU index and name in use for debugging.
    current_device = torch.cuda.current_device()
    print(
        f"[Process {dist.get_rank()}] Using GPU index: {current_device} - {torch.cuda.get_device_name(current_device)}"
    )


def cleanup_distributed():
    dist.destroy_process_group()


def get_dataloaders(batch_size, world_size, rank):
    """Download FashionMNIST and create distributed train/test dataloaders.

    For a simpler model, we use the native image size (28x28) and basic normalization.
    """
    transform_train = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    transform_test = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    train_dataset = torchvision.datasets.FashionMNIST(
        root="./data", train=True, download=True, transform=transform_train
    )
    test_dataset = torchvision.datasets.FashionMNIST(
        root="./data", train=False, download=True, transform=transform_test
    )

    # Create a DistributedSampler so that each process sees a unique subset.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank
    )
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset, num_replicas=world_size, rank=rank
    )

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        sampler=test_sampler,
        num_workers=4,
        pin_memory=True,
    )

    return train_loader, test_loader


# Define a simple CNN suitable for FashionMNIST.
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # FashionMNIST images are 28x28 grayscale images.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def train(args):
    setup_distributed()
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Initialize wandb in shared mode.
    wandb_run_id = os.environ.get("WANDB_RUN_ID", "default_run_id")
    settings = wandb.Settings(
        mode="shared",
        x_stats_sampling_interval=1,
        # GPU index to capture metrics from.
        # In DDP, each process has a single GPU, but all GPUs on the node may be visible.
        x_stats_gpu_device_ids=[local_rank],
        x_label=f"rank-{global_rank}",
    )
    if global_rank != 0:
        # Do not upload wandb files except console logs.
        settings.x_primary = False
        # Do not change the state of the run on run.finish().
        settings.x_update_finish_state = False

    run = wandb.init(
        project="simple-cnn-ddp",
        id=wandb_run_id,
        config={
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "world_size": world_size,
        },
        settings=settings,
    )

    # Update the run metadata with the number of CPUs and GPUs in the cluster.
    run._metadata.gpu_count = world_size
    run._metadata.cpu_count = 8
    run._metadata.cpu_count_logical = 16

    # Log the source code from the current directory.
    # Note: It is recommended to initialize a git repository in this directory for better version tracking,
    # but run.log_code('.') will work even without a git repo.
    run.log_code("/")

    # Create the model, move to GPU, and wrap with DDP.
    model = SimpleCNN().cuda()
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train_loader, test_loader = get_dataloaders(
        args.batch_size, world_size, global_rank
    )

    num_batches = len(train_loader)
    for epoch in range(args.epochs):
        model.train()
        epoch_loss = 0.0
        train_loader.sampler.set_epoch(epoch)  # ensure data shuffling across epochs

        start_time = time.time()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / num_batches

        # Evaluate on test data.
        model.eval()
        correct = 0
        total = 0
        # For misclassification logging (only on rank 0 every 20 epochs)
        misclassified_examples = []
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.cuda(non_blocking=True)
                targets = targets.cuda(non_blocking=True)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                # If this is an epoch for logging misclassified examples (and rank 0), collect a few.
                if (
                    global_rank == 0
                    and epoch % 20 == 0
                    and epoch > 0
                    and len(misclassified_examples) < 10
                ):
                    # Iterate through the batch and store misclassified images.
                    for i in range(inputs.size(0)):
                        if predicted[i] != targets[i]:
                            caption = f"pred: {predicted[i].item()}, true: {targets[i].item()}"
                            # wandb.Image will handle the tensor appropriately (convert to CPU)
                            misclassified_examples.append(
                                wandb.Image(inputs[i].detach().cpu(), caption=caption)
                            )
                            if len(misclassified_examples) >= 10:
                                break

        accuracy = 100.0 * correct / total
        print(
            f"Epoch [{epoch+1}/{args.epochs}] Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%"
        )
        if global_rank == 0:
            log_data = {
                "epoch": epoch,
                "avg_loss": avg_loss,
                "accuracy": accuracy,
                "global_rank": global_rank,
            }
            if epoch % 20 == 0 and misclassified_examples:
                log_data["misclassified_examples"] = misclassified_examples
            run.log(log_data)

    run.finish()
    cleanup_distributed()


def parse_args():
    parser = argparse.ArgumentParser(
        description="DDP Training Example with SimpleCNN on FashionMNIST"
    )
    parser.add_argument("--epochs", type=int, default=10, help="number of total epochs")
    parser.add_argument(
        "--batch_size", type=int, default=128, help="batch size per process"
    )
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    train(args)


Sample run

Correct hardware overview:

Logs from the different ranks / processes:

Log files get uploaded separately from each process:
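
Besides the screenshots, you can confirm programmatically that all ranks wrote into the same run by pulling it back through the W&B public API. A minimal sketch, with <ENTITY> as a placeholder for your W&B entity; the project and run ID match the setup above:

import wandb

api = wandb.Api()
run = api.run("<ENTITY>/simple-cnn-ddp/liquid-tranquility-51")

print(run.name, run.state)
print(run.config)  # epochs, batch_size, learning_rate, world_size

# Scalar history logged by rank 0 of the training job.
for row in run.scan_history(keys=["epoch", "avg_loss", "accuracy"]):
    print(row)

# The console logs uploaded separately by each rank show up as distinct files.
for f in run.files():
    print(f.name)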



