
How to Evaluate GANs using Frechet Inception Distance (FID)

In this article, we will briefly discuss GAN evaluation and how to implement a Frechet Inception Distance (FID) evaluation pipeline.
If you are familiar with Generative Adversarial Networks (GANs) and have trained one, you must have wondered which model checkpoint to use for inference. In image classification, you use the model checkpoint that gives the best validation accuracy.
However, that is not the case for GANs. GAN training is hard and unstable, and suffers from issues like mode collapse. If you want to learn more about generative models in general, check out Towards Deep Generative Modeling with W&B.

Try to train your own simple GAN →

The training metrics are shown in the media panel below. The question is: by looking at the training metrics alone, can you decide which model checkpoint to use to generate new images?





[Media panel: GAN training metrics (run set: 7 runs)]

A naive way to evaluate a GAN is to "babysit" the model training process, i.e., manually look at the images generated from model checkpoints. You can use a callback to log a batch of generated images every n epochs.
If you have checked out the Colab notebook linked above, you will see that I have implemented a custom Keras callback to log the generated images.
import tensorflow as tf
import wandb

class GeneratedImageLogger(tf.keras.callbacks.Callback):
    def __init__(self, noise_dim, batch_size=32):
        super(GeneratedImageLogger, self).__init__()
        # Fixed noise vectors, so the same latent points are visualized every epoch
        self.noise = tf.random.normal([batch_size, noise_dim])

    def on_epoch_end(self, epoch, logs=None):
        # `generator` is the generator network defined earlier in the notebook
        generated_image = generator(self.noise, training=False)

        wandb.log({"gen_images": [wandb.Image(image)
                                  for image in generated_image]})
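To use the callback, pass an instance to fit. A minimal sketch, assuming the notebook's combined gan model, its trainloader, and a latent dimension of 128 (all illustrative names):

gan.fit(trainloader, epochs=EPOCHS,
        callbacks=[GeneratedImageLogger(noise_dim=128)])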
The result of using the callback is shown below. Click on the ⚙️ icon in the media panel shown below and change the number of steps to visualize the images generated at every epoch.

[Media panel: generated images logged at every epoch (run set: 1 run)]



GAN Evaluation




Figure 1 (Source)
  • In supervised image classification, evaluation is straightforward: we compare the predicted output to the ground-truth output.
  • With a GAN, however, you pass in some random noise and get back a fake (generated) image, which we want to look as real as possible. So how exactly can you quantify the realism of a generated image? In other words, how can you evaluate a GAN?
We will start by setting two simple properties for our evaluation metric:
  • Fidelity: We want our GAN to generate high-quality images.
  • Diversity: Our GAN should capture the variety inherent in the training dataset.
Thus our evaluation metric should account for both properties. However, comparing images on fidelity and diversity can be challenging, because what exactly should you compare? Two widely used approaches to compare images in computer vision are:
  • Pixel Distance: A naive measure where we subtract the two images' pixel values. It is not reliable: two perceptually identical images can be far apart pixel-wise, e.g., after a one-pixel shift or a slight brightness change.
  • Feature Distance: We take a pre-trained image classification model and use the activations of an intermediate layer. This vector is a high-level representation of the image, and computing a distance on such representations gives a more stable and reliable metric. The sketch after this list contrasts the two.
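To make the contrast concrete, here is a minimal sketch of both distances, assuming a pre-trained Inception V3 as the feature extractor; the random images and batch size are purely illustrative:

import numpy as np
import tensorflow as tf

# Two illustrative image batches in [0, 1] with shape (N, 299, 299, 3)
real_images = np.random.rand(8, 299, 299, 3).astype("float32")
fake_images = np.random.rand(8, 299, 299, 3).astype("float32")

# Pixel distance: mean squared difference of raw pixels. Brittle, because a
# small shift or brightness change produces a large distance between
# perceptually identical images.
pixel_distance = np.mean((real_images - fake_images) ** 2)

# Feature distance: compare activations of a pre-trained classifier instead.
feature_extractor = tf.keras.applications.InceptionV3(
    include_top=False, weights="imagenet", pooling="avg")
real_features = feature_extractor.predict(real_images)  # shape (N, 2048)
fake_features = feature_extractor.predict(fake_images)  # shape (N, 2048)
feature_distance = np.mean((real_features - fake_features) ** 2)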
Now that we have covered the groundwork, let's quickly see how we can use the Frechet Inception Distance (FID) to evaluate GANs.

Frechet Inception Distance (FID)

This is one of the most popular metrics for measuring the feature distance between the real and the generated images. Frechet Distance is a measure of similarity between curves that takes into account the location and ordering of the points along the curves. It can be used to measure the distance between two distributions as well.

Evaluate your GAN using FID →

Mathematically, the Frechet Distance is used to compute the distance between two multivariate normal distributions. For univariate normal distributions, the Frechet Distance is given as

d(X, Y) = (\mu_X - \mu_Y)^2 + (\sigma_X - \sigma_Y)^2

where μ and σ are the mean and standard deviation of the normal distributions, and X and Y are the two normal distributions being compared.
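As a quick sanity check, the univariate formula can be evaluated directly; the numbers here are made up:

def frechet_1d(mu_x, sigma_x, mu_y, sigma_y):
    # Univariate Frechet distance between N(mu_x, sigma_x^2) and N(mu_y, sigma_y^2)
    return (mu_x - mu_y) ** 2 + (sigma_x - sigma_y) ** 2

print(frechet_1d(0.0, 1.0, 2.0, 1.5))  # (0 - 2)^2 + (1 - 1.5)^2 = 4.25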
In the context of computer vision, and GAN evaluation in particular, we use the feature distance described above. We use the Inception V3 model pre-trained on the ImageNet dataset. The use of activations from the Inception V3 model to summarize each image gives the score its name, "Frechet Inception Distance."
The activations are taken from the penultimate pooling layer (global average pooling if you are using TensorFlow). We assume the output vector of shape (2048,) is well approximated by a multivariate normal distribution.
The Frechet Inception Distance for multivariate normal distributions is given by

FID = \|\mu_X - \mu_Y\|^2 + \mathrm{Tr}\big(\Sigma_X + \Sigma_Y - 2\sqrt{\Sigma_X \Sigma_Y}\big)

where X and Y are the real and fake embeddings (activations from the Inception model), assumed to follow two multivariate normal distributions. μ_X and μ_Y are the means of those embedding distributions, Tr is the trace of a matrix, and Σ_X and Σ_Y are the covariance matrices of the embeddings; the square root here denotes the matrix square root.


Implement FID using TensorFlow

In this section, we will look at the implementation of a GAN evaluation pipeline using the FID score and walk through its key components.

Evaluate your GAN using FID →

  1. Load the pre-trained Inception V3 model:
    inception_model = tf.keras.applications.InceptionV3(include_top=False,
                                                        weights="imagenet",
                                                        pooling='avg')
  2. Compute the embeddings for real and generated images. Note that the authors of GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium recommend using a minimum sample size of 10,000 to calculate the FID; otherwise, the true FID of the generator is underestimated.
    import math
    import numpy as np
    from tqdm import tqdm

    def compute_embeddings(dataloader, count):
        image_embeddings = []
        dataiter = iter(dataloader)  # create the iterator once, not once per batch

        for _ in tqdm(range(count)):
            images = next(dataiter)
            embeddings = inception_model.predict(images)
            image_embeddings.extend(embeddings)

        return np.array(image_embeddings)

    count = math.ceil(10000 / BATCH_SIZE)

    # compute embeddings for real images
    real_image_embeddings = compute_embeddings(trainloader, count)

    # compute embeddings for generated images
    generated_image_embeddings = compute_embeddings(genloader, count)

    real_image_embeddings.shape, generated_image_embeddings.shape
  3. Here trainloader and genloader are tf.data datasets. Check out the Colab notebook for implementation details; a sketch of what genloader might look like follows this list.
  4. With the real and generated image embeddings in hand, we can compute the FID score.
    from scipy import linalg

    def calculate_fid(real_embeddings, generated_embeddings):
        # calculate mean and covariance statistics
        mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
        mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(generated_embeddings, rowvar=False)
        # calculate sum of squared differences between the means
        ssdiff = np.sum((mu1 - mu2) ** 2.0)
        # calculate sqrt of the product of the covariance matrices
        covmean = linalg.sqrtm(sigma1.dot(sigma2))
        # discard tiny imaginary components introduced by sqrtm
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        # calculate score
        fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid

    fid = calculate_fid(real_image_embeddings, generated_image_embeddings)
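As promised in step 3, here is a minimal sketch of how genloader could be built, assuming the trained generator model, its latent dimension noise_dim, and BATCH_SIZE from the training notebook (the exact construction in the Colab may differ):

# Generate roughly 10,000 fake images in batches, then wrap them in tf.data
fake_batches = [generator(tf.random.normal([BATCH_SIZE, noise_dim]), training=False)
                for _ in range(count)]
genloader = tf.data.Dataset.from_tensor_slices(
    tf.concat(fake_batches, axis=0)).batch(BATCH_SIZE)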
Now that you know how to compute the FID score, let's compute it for several GAN checkpoints. I have used W&B Artifacts to store a checkpoint every 5 epochs. Using the functions described above, I computed the FID score for each checkpoint; a sketch of such an evaluation loop is shown below, and the result appears in the media panel that follows.
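A sketch of that evaluation loop, assuming the checkpoints were logged as SavedModel artifacts named generator (the artifact name, versioning, and project are illustrative, not the exact names from my run):

import wandb

with wandb.init(project="gan-evaluation") as run:
    for version in range(10):  # one artifact version per 5-epoch checkpoint
        artifact = run.use_artifact(f"generator:v{version}")
        generator = tf.keras.models.load_model(artifact.download())

        # Regenerate fake images with this checkpoint (as in the sketch above)
        fake_batches = [generator(tf.random.normal([BATCH_SIZE, noise_dim]),
                                  training=False) for _ in range(count)]
        genloader = tf.data.Dataset.from_tensor_slices(
            tf.concat(fake_batches, axis=0)).batch(BATCH_SIZE)

        # Reuse the real-image embeddings, which only need computing once
        generated_image_embeddings = compute_embeddings(genloader, count)
        fid = calculate_fid(real_image_embeddings, generated_image_embeddings)
        wandb.log({"FID": fid})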



[Media panel: FID scores for GAN checkpoints (run set: 1 run)]


Observations

  • The FID score decreases as the model checkpoints improve. Pick the checkpoint with the lowest FID score for inference.
  • The FID score is relatively high overall because the Inception model is trained on ImageNet, which consists of natural images, while our GAN is trained on the FashionMNIST dataset.

Shortcomings of FID

  • It uses a pre-trained Inception model, which may not capture the features relevant to every domain. This can inflate the FID score, as in the case above.
  • It needs a large sample size. The minimum recommended sample size is 10,000. For high-resolution images (say, 512x512 pixels), this can be computationally expensive and slow to run.
  • Only limited statistics (mean and covariance) are used to compute the FID score, which assumes the embeddings follow a multivariate normal distribution.
