How to Evaluate GANs using Frechet Inception Distance (FID)

In this report, we will briefly discuss the gotchas of GAN evaluation and how to implement an FID evaluation pipeline. Made by Ayush Thakur using W&B.

🌟 Introduction

If you are familiar with Generative Adversarial Networks (GANs) and have trained one, you must have wondered which model checkpoint to use for inference. In image classification, you use the model checkpoint that gives the best validation accuracy.

However, that is not the case for GANs. GAN training is hard and unstable, with issues like mode collapse. If you want to learn more about generative models in general, check out Towards Deep Generative Modeling with W&B.

Try to train your own simple GAN $\rightarrow$

The training metrics are shown in the media panels below. The question is: by looking at the training metrics, can you decide which model checkpoint to use to generate new images?

A naive way to evaluate a GAN is to "babysit" the model training process, i.e., manually look at the images generated from model checkpoints. You can use a callback to log a batch of generated images every n epochs.

If you have checked out the colab notebook linked above, you will see that I have implemented a custom Keras callback to log the generated images:

```python
class GeneratedImageLogger(tf.keras.callbacks.Callback):
    def __init__(self, noise_dim, batch_size=32):
        super(GeneratedImageLogger, self).__init__()
        # fixed noise vector, so logged images are comparable across epochs
        self.noise = tf.random.normal([batch_size, noise_dim])

    def on_epoch_end(self, epoch, logs=None):
        # `generator` is the generator model defined in the notebook
        generated_image = generator(self.noise, training=False)

        wandb.log({"gen_images": [wandb.Image(image)
                                  for image in generated_image]})
```
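
To use the callback, attach it when you call `fit` — a minimal sketch, assuming the `gan` model, `dataset`, and `noise_dim` defined in the notebook:

```python
# log a fixed noise batch rendered by the generator after every epoch
image_logger = GeneratedImageLogger(noise_dim=noise_dim)
gan.fit(dataset, epochs=50, callbacks=[image_logger])
```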

The result of using the callback is shown below. Click on the ⚙️ icon in the media panel below and change the number of steps to visualize the images generated at every epoch.

GAN Evaluation

-> Figure 1: (Source) <-

We will start by setting two simple properties for our evaluation metric:

- Fidelity: the quality of the generated images, i.e., how realistic they look.
- Diversity: the variety of the generated images, i.e., how well they cover the spread of the real data.

Thus, our evaluation metric should evaluate against both properties. However, comparing images on fidelity and diversity can be challenging, because what exactly should you compare? Two widely used approaches for comparing images in computer vision are:

- Pixel distance: the naive approach of taking the difference between corresponding pixel values. Small shifts in an image can produce large pixel distances, so it is not a reliable measure.
- Feature distance: comparing high-level features of the images, extracted with a pre-trained image classification model. This is the approach used by FID (see the toy sketch below).
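
To make the contrast concrete, here is a toy sketch with random arrays standing in for images and embeddings:

```python
import numpy as np

# pixel distance: compare raw pixel values; brittle, since a small shift
# or brightness change produces a large distance
img_a = np.random.rand(64, 64, 3)
img_b = np.random.rand(64, 64, 3)
pixel_distance = np.sum((img_a - img_b) ** 2)

# feature distance: compare embeddings that summarize high-level content
emb_a = np.random.rand(2048)  # stand-in for an Inception embedding
emb_b = np.random.rand(2048)
feature_distance = np.sum((emb_a - emb_b) ** 2)
```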

Now that we have covered some ground, let's quickly see how we can use Frechet Inception Distance (FID) to evaluate GANs.

❄️ Frechet Inception Distance (FID)

This is one of the most popular metrics for measuring the feature distance between the real and the generated images. Frechet Distance is a measure of similarity between curves that takes into account the location and ordering of the points along the curves. It can be used to measure the distance between two distributions as well.

Evaluate your GAN using FID $\rightarrow$

Mathematically, the Frechet Distance is used to compute the distance between two multivariate normal distributions. For two univariate normal distributions, the Frechet Distance is given as,

$d(X, Y) = (\mu_X - \mu_Y)^2 + (\sigma_X - \sigma_Y)^2$

where $\mu_X$, $\mu_Y$ and $\sigma_X$, $\sigma_Y$ are the means and standard deviations of the two normal distributions $X$ and $Y$.
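
For instance, for $X \sim N(0, 1)$ and $Y \sim N(0.5, 2^2)$, the distance is $(0 - 0.5)^2 + (1 - 2)^2 = 1.25$. As a one-function sketch:

```python
# univariate Frechet distance between N(mu_x, sigma_x^2) and N(mu_y, sigma_y^2)
def frechet_1d(mu_x, sigma_x, mu_y, sigma_y):
    return (mu_x - mu_y) ** 2 + (sigma_x - sigma_y) ** 2

print(frechet_1d(0.0, 1.0, 0.5, 2.0))  # 1.25
```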

In the context of computer vision, especially GAN evaluation, we use feature distance as described above. We use the Inception V3 model pre-trained on the ImageNet dataset. **The use of activations from the Inception V3 model to summarize each image gives the score its name of “Frechet Inception Distance.”**

These activations are taken from the penultimate pooling layer (a global average pooling layer, if you are using TensorFlow). We assume the output vector, of shape (2048,), can be approximated by a multivariate normal distribution.
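
As a quick sanity check, loading the pooled Inception V3 backbone and passing a dummy batch through it confirms the 2048-dimensional output (a sketch; the model is loaded exactly as in the implementation section below):

```python
import tensorflow as tf

# no classification head; global average pooling yields one 2048-d vector per image
model = tf.keras.applications.InceptionV3(
    include_top=False, weights="imagenet", pooling="avg")
print(model(tf.random.normal([1, 299, 299, 3])).shape)  # (1, 2048)
```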

The Frechet Inception Distance between two multivariate normal distributions is given by,

$FID = ||\mu_X - \mu_Y||^2 + Tr(\Sigma_X + \Sigma_Y - 2\sqrt{\Sigma_X \Sigma_Y})$

where $X$ and $Y$ are the real and fake embeddings (activations from the Inception model), assumed to be two multivariate normal distributions, $\mu_X$ and $\mu_Y$ are the means of the embedding vectors, $Tr$ is the trace of a matrix, and $\Sigma_X$ and $\Sigma_Y$ are the covariance matrices of the embeddings.

🐦 Implement FID using TensorFlow

In this section, we will look at the implementation of a GAN evaluation pipeline using the FID score and walk through its key components.

Evaluate your GAN using FID $\rightarrow$

  1. Load the pre-trained Inception V3 model:

    inception_model = tf.keras.applications.InceptionV3(include_top=False,
                                                        weights="imagenet",
                                                        pooling="avg")
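
    # Note (sketch): Inception V3 expects 299x299 RGB inputs scaled to [-1, 1].
    # If your batches hold [0, 255] pixel images, resize and preprocess first:
    def preprocess(images):
        images = tf.image.resize(images, (299, 299))  # Inception V3 input size
        # preprocess_input maps [0, 255] pixels to [-1, 1]
        return tf.keras.applications.inception_v3.preprocess_input(images)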
    
  2. Compute the embeddings for real images and generated images. Note that the authors of GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium recommend using a minimum sample size of 10,000 to calculate the FID; otherwise, the true FID of the generator is underestimated.

    def compute_embeddings(dataloader, count):
        image_embeddings = []
        # create the iterator once, so each loop step draws a fresh batch
        dataset_iter = iter(dataloader)

        for _ in tqdm(range(count)):
            images = next(dataset_iter)
            embeddings = inception_model.predict(images)

            image_embeddings.extend(embeddings)

        return np.array(image_embeddings)

    count = math.ceil(10000 / BATCH_SIZE)

    # compute embeddings for real images
    real_image_embeddings = compute_embeddings(trainloader, count)

    # compute embeddings for generated images
    generated_image_embeddings = compute_embeddings(genloader, count)

    # sanity-check the embedding shapes
    real_image_embeddings.shape, generated_image_embeddings.shape

Here trainloader and genloader are tf.data datasets. Check out the colab notebook for implementation details.
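
If you need a quick stand-in for genloader, here is a minimal sketch, assuming the `generator`, `noise_dim`, and `BATCH_SIZE` defined in the notebook:

    def make_genloader(generator, num_samples=10000):
        # sample a fixed pool of noise vectors and batch them
        noise = tf.random.normal([num_samples, noise_dim])
        dataset = tf.data.Dataset.from_tensor_slices(noise).batch(BATCH_SIZE)
        # map each noise batch through the generator to produce fake images
        return dataset.map(lambda z: generator(z, training=False))

    genloader = make_genloader(generator)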

  3. With the real and generated image embeddings in hand, we compute the FID score:

    def calculate_fid(real_embeddings, generated_embeddings):
        # calculate mean and covariance statistics
        mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
        mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(generated_embeddings, rowvar=False)
        # calculate sum squared difference between means
        ssdiff = np.sum((mu1 - mu2)**2.0)
        # calculate sqrt of product between cov (needs `from scipy import linalg`)
        covmean = linalg.sqrtm(sigma1.dot(sigma2))
        # check and correct imaginary numbers from sqrt
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        # calculate score
        fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid

    fid = calculate_fid(real_image_embeddings, generated_image_embeddings)

Now that you know how to compute the FID score, let's compute it for GAN checkpoints. I have used W&B Artifacts to store the checkpoints every 5 epochs. Using the functions described above, I computed the FID score for each checkpoint. The results are shown in the media panel below.
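
For illustration, the evaluation loop might look like the sketch below; `load_generator` is a hypothetical helper that restores a generator from a checkpoint artifact, and `make_genloader` is the sketch from the previous section:

```python
# score the checkpoint saved every 5 epochs and log the FID to W&B
for epoch in range(5, 55, 5):
    generator = load_generator(epoch)  # hypothetical checkpoint-restoring helper
    genloader = make_genloader(generator)
    generated_image_embeddings = compute_embeddings(genloader, count)
    fid = calculate_fid(real_image_embeddings, generated_image_embeddings)
    wandb.log({"epoch": epoch, "fid": fid})
```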

Observations

🐸 Shortcomings of FID