Skip to main content

PyTorch Monitoring in Five Lines

Created on November 12|Last edited on November 12


cifar10.py

import torch
import torchvision
import torchvision.transforms as transforms
import wandb

# Start the W&B run and grab its config object, which holds the
# hyperparameters read by the rest of this script.
wandb.init()
config = wandb.config

# Hyperparameters, assigned onto the config one attribute at a time
# in the order listed here.
for _name, _value in (
    ("batch_size", 20),
    ("lr", 0.001),
    ("momentum", 0.9),
    ("epochs", 10),
    ("hidden_nodes", 120),
    ("conv1_channels", 5),
    ("conv2_channels", 16),
):
    setattr(config, _name, _value)


# ----------------------------------------------------------------------
# Data pipeline.
# torchvision datasets yield PILImage images of range [0, 1]; they are
# converted to Tensors and normalized to the range [-1, 1].
# ----------------------------------------------------------------------

to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([to_tensor, normalize])

# Training split: reshuffled by the loader.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=2,
)

# Test split: kept in dataset order (no shuffling).
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names, indexed by integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small convolutional classifier producing 10 output scores.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed
    three fully connected layers. ``fc1`` expects a 5x5 feature map per
    channel, which is what 32x32 inputs produce after the two conv/pool
    stages (32 -> 28 -> 14 -> 10 -> 5).

    Layer widths are read from the module-level ``config``
    (``conv1_channels``, ``conv2_channels``, ``hidden_nodes``).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, config.conv1_channels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(config.conv1_channels, config.conv2_channels, 5)
        self.fc1 = nn.Linear(config.conv2_channels * 5 * 5, config.hidden_nodes)
        self.fc2 = nn.Linear(config.hidden_nodes, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw class scores for a batch of images ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample rather than with a hard-coded 16 * 5 * 5, so
        # the width always matches fc1 when conv2_channels is changed and
        # the batch dimension can never be silently reshaped.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Build the model and move it to the GPU when one is available,
# otherwise keep it on the CPU.
net = Net()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = net.to(device)


# Register the model with W&B so its gradients can be inspected in the
# run. wandb.watch is the current name of the deprecated wandb.hook_torch.
wandb.watch(net)

import torch.optim as optim

# Cross-entropy loss optimized with SGD; learning rate and momentum come
# from the run config.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=config.lr, momentum=config.momentum)

# Train for config.epochs epochs. After each epoch, evaluate on the test
# set and log example predictions, overall accuracy, per-class accuracy
# and the mean training loss to W&B.
for epoch in range(config.epochs):  # loop over the dataset multiple times

    running_loss = 0.0  # loss summed since the last progress print
    epoch_loss = 0.0    # loss summed over the whole epoch
    num_batches = 0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

    # Single gradient-free pass over the test set that gathers the example
    # images, the overall accuracy and the per-class accuracy together.
    example_images = []
    class_correct = [0] * len(classes)
    class_total = [0] * len(classes)
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(testloader):
            images, labels = images.to(device), labels.to(device)

            _, predicted = torch.max(net(images), 1)
            predicted_ids = predicted.tolist()
            label_ids = labels.tolist()

            # The first test batch provides the example images, each
            # captioned with its predicted class name.
            if batch_idx == 0:
                example_images = [
                    wandb.Image(image, caption=classes[pred_id])
                    for image, pred_id in zip(images, predicted_ids)
                ]

            # Count every sample in the batch, whatever the batch size
            # (a fixed range(4) would skip most of a 20-sample batch).
            for label_id, pred_id in zip(label_ids, predicted_ids):
                class_total[label_id] += 1
                class_correct[label_id] += int(pred_id == label_id)

    correct = sum(class_correct)
    total = sum(class_total)
    test_acc = 100 * correct / total if total else 0.0

    class_acc = {}
    for class_name, n_correct, n_total in zip(classes, class_correct, class_total):
        # Guard against classes that never appear in the test data.
        accuracy = 100 * n_correct / n_total if n_total else 0.0
        print('Accuracy of %5s : %2d %%' % (class_name, accuracy))
        class_acc["Accuracy of %5s" % class_name] = accuracy

    # commit=False holds the per-class metrics so they are recorded in the
    # same step as the summary metrics logged right after.
    wandb.log(class_acc, commit=False)
    wandb.log({"Examples": example_images,
               "Test Acc": test_acc,
               # Mean training loss over the epoch, not the partial sum
               # left over since the last progress print.
               "Loss": epoch_loss / num_batches if num_batches else 0.0})

I love PyTorch and I love tracking my experiments. It’s possible to use TensorBoard with PyTorch, but it can feel a little clunky. We recently added a feature to make it dead simple to monitor your PyTorch models with wandb!

I started with the PyTorch CIFAR-10 tutorial. This tutorial is fantastic, but it uses matplotlib to show the images, which can be annoying on a remote server; it doesn’t plot the accuracy or loss curves; and it doesn’t let me inspect the gradients of the layers. Let’s fix all that with just a couple of lines of code!

At the top I add the lines:

import wandb
wandb.init()

This starts a wandb process that tracks the input hyperparameters and lets me save metrics and files. It also saves stdout and stderr, and automatically tracks my GPU usage and other system metrics.

020406080Step354045505560

System Metrics

Here's an example of the system metrics automatically captured by W&B.

system metrics

Now I can add a log command at the end of each epoch and easily see how my network is performing on each class:

   for i in range(10):
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))
        class_acc["Accuracy of %5s" % (classes[i])] = 100 * class_correct[i] / class_total[i]

    wandb.log(class_acc)
020406080Step304050607080
All Runs
12