How to Implement Deep Convolutional Generative Adversarial Networks (DCGANs) in PyTorch

A short tutorial about implementing Deep Convolutional Generative Adversarial Networks in PyTorch, with a Colab to help you follow along.


Introduction

This post assumes a basic understanding of Generative Adversarial Networks.
In this report, we'll walk through how you can implement a Deep Convolutional Generative Adversarial Network (DCGAN) using the PyTorch framework.
First introduced in the paper "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks" by Alec Radford, Luke Metz, and Soumith Chintala, the DCGAN is a simple extension of the original GAN that uses convolutional blocks exclusively as the core components of both the discriminator and the generator.
Lastly, before jumping in, if you'd like to follow along with this piece in a Colab with executable code, you can find that right here:

The Implementation

The strategy to train any GAN (at a high level) is quite simple.
First, we have the generator (in this case a CNN) create some images, calculate the generator's loss from the discriminator's verdict on those images, and perform backpropagation. Then we feed these generated images along with real images to the discriminator (again a CNN), combine the losses on the two sets (the real images from the dataset and the images produced by the generator), and perform backpropagation again.
In this way we're training two networks: one that generates images (the generator) and another that learns to discriminate between real images and the ones generated by our other network. Each network has its own optimizer, and both are usually evaluated using the same loss function.
In this case we use the Binary Cross Entropy Loss for both the generator and the discriminator. For more details, refer to this post.
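The training loop below assumes we already have a Generator and a Discriminator module. As a rough sketch of what those might look like, built entirely from convolutional blocks and producing 28x28 grayscale images (the channel sizes, output resolution, and latent_dim here are illustrative assumptions, not the exact architecture from the paper):

import torch.nn as nn

latent_dim = 100  # Size of the noise vector (an illustrative choice)

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Transposed convolutions upsample the noise vector into an image
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 128, 7, 1, 0), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1), nn.Tanh(),  # Outputs a 1x28x28 image
        )

    def forward(self, z):
        # Reshape the flat noise vector into a (latent_dim, 1, 1) feature map
        return self.model(z.view(z.size(0), latent_dim, 1, 1))


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Strided convolutions downsample the image to a single real/fake score
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 7, 1, 0), nn.Sigmoid(),  # Probability that the input is real
        )

    def forward(self, img):
        return self.model(img).view(img.size(0), 1)

Note how the two networks mirror each other: transposed convolutions upsample noise into an image, while strided convolutions downsample an image into a score.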
Let's see an example training loop in PyTorch:
import torch
import wandb

# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Optimizers (Adam with the settings suggested in the DCGAN paper)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# num_epochs and dataloader are assumed to be defined elsewhere
for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):

        # Ground-truth labels: 1 for real images, 0 for generated ones
        valid = torch.ones(real_imgs.size(0), 1)
        fake = torch.zeros(real_imgs.size(0), 1)

        # --------- #
        # Generator #
        # --------- #

        optimizer_G.zero_grad()
        z = torch.randn(real_imgs.size(0), latent_dim)  # Sample latent noise
        gen_imgs = generator(z)  # Generate a batch of images
        # The generator wants the discriminator to label its images as real
        gen_loss = adversarial_loss(discriminator(gen_imgs), valid)
        gen_loss.backward()
        optimizer_G.step()

        # ------------- #
        # Discriminator #
        # ------------- #

        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach() stops discriminator gradients from flowing into the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        disc_loss = (real_loss + fake_loss) / 2
        disc_loss.backward()
        optimizer_D.step()

        # Log both losses to Weights & Biases
        wandb.log({"Discriminator Loss": disc_loss.item(), "Generator Loss": gen_loss.item()})

The Results

Now that you've seen how to implement the networks and the training strategy, let's see how Weights & Biases allows us to easily visualize important metrics and compare them using Panels.
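Besides scalar losses, you can also log a few generated samples each epoch to watch the generator's progress qualitatively. Here's a small, assumed addition to the training loop above:

# At the end of each epoch, log the first eight generated images
# so they appear in a W&B media panel
wandb.log({"Generated Images": [wandb.Image(img) for img in gen_imgs[:8]]})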
For example, here's a quick comparison of the discriminator loss grouped by latent dimension, which you'll find in the Colab linked above:

[Panel: discriminator loss for a run set of 3 runs, grouped by latent dimension]

As we can see from the plots, a latent dimension of 100 yields the lowest discriminator loss, probably because the larger latent space allows for greater learning capacity. You can also try tuning other hyperparameters, such as the batch size or the image size.
Weights & Biases Sweeps makes this incredibly easy by automatically running your pipeline using an agent. For more details, please refer to our Sweeps Quickstart Guide.
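To illustrate, a sweep over the latent dimension and batch size could be configured like this (the swept values and the train function are assumptions for this sketch; the metric name matches the wandb.log call in the training loop above):

import wandb

# Randomly search hyperparameter combinations, minimizing the
# discriminator loss logged during training
sweep_config = {
    "method": "random",
    "metric": {"name": "Discriminator Loss", "goal": "minimize"},
    "parameters": {
        "latent_dim": {"values": [25, 50, 100]},
        "batch_size": {"values": [32, 64, 128]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="dcgan")
wandb.agent(sweep_id, function=train, count=10)  # train reads wandb.config

The agent calls your training function once per run, each time with a new combination drawn from the sweep configuration.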
If you'd like to try this yourself, here's the Colab to do so:

Summary

In this article, you saw how you can implement Deep Convolutional Generative Adversarial Networks using the PyTorch framework, and how Weights & Biases allows you to easily visualize important metrics. To see the full suite of W&B features, please check out this short 5-minute guide.
If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments down below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental development topics like GPU Utilization and Saving Models.
