PyTorch Monitoring in Five Lines
cifar10.py
import torch
import torchvision
import torchvision.transforms as transforms
import wandb
wandb.init()
config = wandb.config
config.batch_size = 20
config.lr = 0.001
config.momentum = 0.9
config.epochs = 10
config.hidden_nodes = 120
config.conv1_channels = 5
config.conv2_channels = 16
########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, config.conv1_channels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(config.conv1_channels, config.conv2_channels, 5)
        self.fc1 = nn.Linear(config.conv2_channels * 5 * 5, config.hidden_nodes)
        self.fc2 = nn.Linear(config.hidden_nodes, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten using the configured channel count, not a hard-coded 16
        x = x.view(-1, config.conv2_channels * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = net.to(device)
wandb.watch(net)  # hook into the model to log gradients and parameters as it trains
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=config.lr, momentum=config.momentum)
for epoch in range(config.epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
    # grab a batch of test images and record the model's predictions on them
    dataiter = iter(testloader)
    images, labels = next(dataiter)
    images, labels = images.to(device), labels.to(device)
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    example_images = [wandb.Image(image, caption=classes[pred.item()])
                      for image, pred in zip(images, predicted)]
    # Let us look at how the network performs on the whole test set.
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    test_acc = 100 * correct / total
    # accuracy per class
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(labels.size(0)):  # the whole batch, not a hard-coded 4
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    class_acc = {}
    for i in range(10):
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))
        class_acc["Accuracy of %5s" % (classes[i])] = 100 * class_correct[i] / class_total[i]

    # commit=False attaches the per-class accuracies to the same logged
    # step as the metrics below
    wandb.log(class_acc, commit=False)
    wandb.log({"Examples": example_images,
               "Test Acc": test_acc,
               "Loss": running_loss})
I love PyTorch, and I love tracking my experiments. It's possible to use TensorBoard with PyTorch, but it can feel a little clunky. We recently added a feature that makes it dead simple to monitor your PyTorch models with wandb!
I started with the PyTorch CIFAR-10 tutorial. The tutorial is fantastic, but it uses matplotlib to show the images (which can be annoying on a remote server), it doesn't plot accuracy or loss curves, and it doesn't let me inspect the gradients of the layers. Let's fix all that with just a couple of lines of code!
At the top I add the lines:
import wandb
wandb.init()
This starts a wandb process that tracks the run's hyperparameters and lets me save metrics and files. It also captures stdout and stderr, and automatically tracks GPU usage and other system metrics.
System Metrics
Here's an example of the system metrics automatically captured by W&B.
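Hyperparameters don't have to be set one attribute at a time, either. Here's a minimal sketch of passing them to wandb.init as a dict (the values mirror the config block in the script above):

import wandb

# A sketch of the same config block above, passed as a single dict;
# wandb.init accepts a config argument and exposes it as wandb.config.
wandb.init(config={
    "batch_size": 20,
    "lr": 0.001,
    "momentum": 0.9,
    "epochs": 10,
    "hidden_nodes": 120,
    "conv1_channels": 5,
    "conv2_channels": 16,
})
config = wandb.config  # values are now available as config.lr, config.epochs, ...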
Now I can add a log command at the end of each epoch and easily see how my network is performing on each class:
for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
    class_acc["Accuracy of %5s" % (classes[i])] = 100 * class_correct[i] / class_total[i]
wandb.log(class_acc)
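And the gradient inspection mentioned above takes just one more line, right after the model is built. A minimal sketch using wandb.watch (the current name for the call that appears as hook_torch in the script above), with a stand-in model so it runs on its own:

import torch.nn as nn
import wandb

wandb.init()
model = nn.Linear(10, 2)  # any nn.Module works; this stands in for Net above

# register hooks on the model so wandb logs gradient (and, with log="all",
# parameter) histograms for every layer as training runs
wandb.watch(model, log="all")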