How to Evaluate GANs using Frechet Inception Distance (FID)
In this article, we will briefly discuss the details of GAN evaluation and how to implement the Frechet Inception Distance (FID) evaluation pipeline.
Created on January 15|Last edited on December 2
If you are familiar with Generative Adversarial Networks (GANs) and have trained one, you may have wondered which model checkpoint to use for inference. In image classification, you simply use the checkpoint that gives the best validation accuracy.
However, that is not the case for GANs. GAN training is hard and unstable, suffers from issues such as mode collapse, and offers no single validation metric that tells you which checkpoint is best. If you want to learn more about generative models in general, check out Towards Deep Generative Modeling with W&B.
Try to train your own simple GAN
The training metrics are shown in the media panels below. Can you decide, just by looking at these training metrics, which model checkpoint to use for generating new images?
A naive way to evaluate a GAN is to "babysit" the training process, i.e., manually look at the images generated by successive model checkpoints. You can use a callback to log a batch of generated images every n epochs.
In the Colab notebook linked above, I implemented a custom Keras callback to log the generated images:
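```python
import tensorflow as tf
import wandb


class GeneratedImageLogger(tf.keras.callbacks.Callback):
    """Logs a fixed batch of generated images to W&B at the end of every epoch."""

    def __init__(self, noise_dim, batch_size=32):
        super().__init__()
        # Fixed noise, so images from different epochs are directly comparable
        self.noise = tf.random.normal([batch_size, noise_dim])

    def on_epoch_end(self, epoch, logs=None):
        # `generator` is the GAN's generator model, defined globally in the notebook
        generated_images = generator(self.noise, training=False)
        wandb.log({"gen_images": [wandb.Image(image.numpy()) for image in generated_images]})
```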
The result of using the callback is shown below. To visualize the images generated at each epoch, click the ⚙️ icon in the media panel below and change the step.
GAN Evaluation

- In supervised image classification, evaluation is straightforward: we compare the predicted labels to the ground-truth labels.
- With a GAN, however, you pass random noise to the generator to get a fake (generated) image, and there is no ground-truth image to compare it against. We want this generated image to look as real as possible. So, how exactly can you quantify its realism? In other words, how can you evaluate a GAN?
We will start by defining two simple properties that our evaluation metric should capture:
- Fidelity: We want our GAN to generate high-quality images.
- Diversity: Our GAN should generate the full variety of images present in the training dataset, not just a few of them.
Thus, our evaluation metric should assess both properties. However, comparing images for fidelity and diversity is challenging: what exactly should you compare? Two approaches to comparing images are widely used in computer vision:
- Pixel Distance: A naive distance measure where we subtract the pixel values of two images. However, this is not a reliable metric: shifting an object by a few pixels can produce a large distance even though the image content is essentially unchanged.
- Feature Distance: We pass the images through a pre-trained image classification model and use the activations of an intermediate layer. This vector is a high-level representation of the image, so a distance computed between such representations is more stable and reliable.
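To make the weakness of pixel distance concrete, here is a toy sketch on a hypothetical synthetic image (an 8x8 white square on a 28x28 black background; `pixel_distance` is an illustrative helper, not part of the notebook). Moving the square 8 pixels to the right, which leaves the content unchanged, gives a larger pixel distance than deleting the square altogether:

```python
import numpy as np


def pixel_distance(a, b):
    """Sum of squared per-pixel differences between two images."""
    return float(np.sum((a.astype(np.float64) - b.astype(np.float64)) ** 2))


# A 28x28 grayscale "image" containing an 8x8 white square
image = np.zeros((28, 28))
image[10:18, 10:18] = 1.0

# The same square moved 8 pixels to the right (same content, new position)
shifted = np.roll(image, shift=8, axis=1)

# An empty image (the square is gone entirely)
blank = np.zeros_like(image)

print(pixel_distance(image, shifted))  # 128.0
print(pixel_distance(image, blank))    # 64.0
```

A feature extractor, by contrast, tends to map both the original and the shifted square to similar representations, which is why feature distances are preferred.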
Now that we have covered some ground, let's see how we can use the Frechet Inception Distance (FID) to evaluate GANs.
Frechet Inception Distance (FID)
FID is one of the most popular metrics for measuring the feature distance between real and generated images. The Frechet distance is a measure of similarity between curves that takes into account the location and ordering of the points along the curves. It can also be used to measure the distance between two probability distributions.
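To illustrate what "location and ordering" means for curves, here is a sketch of the discrete Frechet distance, a common variant computed on sampled points via a standard dynamic program. The function name and the example polylines are illustrative assumptions. Two curves made of the same points but traversed in opposite directions end up far apart:

```python
import math


def discrete_frechet(p, q):
    """Discrete Frechet distance between two polylines given as lists of (x, y) points."""
    n, m = len(p), len(q)
    # ca[i][j]: distance for the prefixes p[:i + 1] and q[:j + 1]
    ca = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            d = math.dist(p[i], q[j])
            if i == 0 and j == 0:
                ca[i][j] = d
            elif i == 0:
                ca[i][j] = max(ca[i][j - 1], d)
            elif j == 0:
                ca[i][j] = max(ca[i - 1][j], d)
            else:
                # Both walkers move forward only, so ordering along each curve matters
                ca[i][j] = max(min(ca[i - 1][j], ca[i - 1][j - 1], ca[i][j - 1]), d)
    return ca[n - 1][m - 1]


a = [(0, 0), (1, 0), (2, 0)]
print(discrete_frechet(a, a))                           # 0.0
print(discrete_frechet(a, [(0, 1), (1, 1), (2, 1)]))    # 1.0 (same shape, shifted up)
print(discrete_frechet(a, list(reversed(a))))           # 2.0 (same points, reversed order)
```

For FID, the same idea is applied to probability distributions rather than curves, which leads to the closed-form expressions for normal distributions.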
Evaluate your GAN using FID
Mathematically, the Frechet distance can be used to compute the distance between two "multivariate" normal distributions. For two "univariate" normal distributions X ~ N(μ_X, σ_X²) and Y ~ N(μ_Y, σ_Y²), the (squared) Frechet distance is given as

$$d^2(X, Y) = (\mu_X - \mu_Y)^2 + (\sigma_X - \sigma_Y)^2$$

where μ_X, μ_Y and σ_X, σ_Y are the means and standard deviations of the normal distributions X and Y.
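As a quick worked example, take two hypothetical normal distributions X ~ N(0, 1²) and Y ~ N(1, 2²) and plug their means and standard deviations into the univariate formula:

```latex
d^2(X, Y) = (\mu_X - \mu_Y)^2 + (\sigma_X - \sigma_Y)^2 = (0 - 1)^2 + (1 - 2)^2 = 1 + 1 = 2
```

Both the shift in the mean and the mismatch in spread contribute to the distance, and it is zero only when the two normal distributions are identical.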
In the context of computer vision, and GAN evaluation in particular, we use the feature distance described above, computed with the Inception V3 model pre-trained on the ImageNet dataset. Using activations from the Inception V3 model to summarize each image gives the score its name: “Frechet Inception Distance.”
The activations are taken from the final pooling layer just before the classification head (global average pooling in the TensorFlow/Keras implementation), which yields a vector of shape (2048,) for each image. We assume that these vectors follow a "multivariate" normal distribution.
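For "multivariate" normal distributions, the squared Frechet distance generalizes to the following expression, which is the quantity reported as the Frechet Inception Distance:

$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)$$

where the subscripts r and g refer to the real and generated (fake) embeddings (activations from the Inception model), each assumed to follow a multivariate normal distribution. μ_r and μ_g are the mean vectors of the embeddings, so the first term is the squared L2 distance between them. Tr is the trace of a matrix (the sum of its diagonal elements), and Σ_r and Σ_g are the covariance matrices of the embeddings.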
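To see how the multivariate formula relates to the univariate one, here is a small NumPy/SciPy check with hypothetical means and standard deviations (`frechet_distance` is an illustrative helper that works on the Gaussian statistics directly). When both covariance matrices are diagonal, the multivariate formula reduces to a sum of per-dimension univariate Frechet distances:

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Squared Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    # sqrtm can return tiny imaginary parts due to numerical error; keep the real part
    covmean = np.real(covmean)
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(sigma1 + sigma2 - 2.0 * covmean))


# Two hypothetical 3-D Gaussians with diagonal covariance matrices
mu_r, std_r = np.array([0.0, 1.0, 2.0]), np.array([1.0, 2.0, 0.5])
mu_g, std_g = np.array([1.0, 1.0, 0.0]), np.array([2.0, 1.0, 0.5])

multivariate = frechet_distance(mu_r, np.diag(std_r ** 2), mu_g, np.diag(std_g ** 2))

# Sum over dimensions of (mu_x - mu_y)^2 + (sigma_x - sigma_y)^2
univariate_sum = float(np.sum((mu_r - mu_g) ** 2 + (std_r - std_g) ** 2))

print(multivariate, univariate_sum)  # both are 7.0, up to floating-point error
```

With full (non-diagonal) covariance matrices the trace term also captures correlations between feature dimensions, which is why FID needs the matrix square root rather than per-dimension standard deviations.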
Implement FID using TensorFlow
In this section, we will walk through the key components of a GAN evaluation pipeline based on the FID score.
To evaluate your GAN using FID:
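- Load the pre-trained Inception V3 model without its classification head:
```python
import tensorflow as tf

# Inception V3 without its ImageNet classification head; global average pooling
# turns the last feature map into one 2048-dimensional embedding per image
inception_model = tf.keras.applications.InceptionV3(
    include_top=False,
    weights="imagenet",
    pooling="avg",
)
```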
- Compute the embeddings for the real and generated images. Note that the authors of GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium recommend using a minimum sample size of 10,000 to calculate the FID; otherwise, the true FID of the generator is underestimated.
```python
import math

import numpy as np
from tqdm import tqdm


def compute_embeddings(dataloader, count):
    """Embed `count` batches from a tf.data.Dataset with the Inception model."""
    image_embeddings = []
    # Iterate over the dataset once; calling next(iter(dataloader)) inside the loop
    # would restart the iterator every time and could return the same batch repeatedly
    for images in tqdm(dataloader.take(count), total=count):
        embeddings = inception_model.predict(images, verbose=0)
        image_embeddings.extend(embeddings)
    return np.array(image_embeddings)


# Number of batches needed for at least 10,000 samples
count = math.ceil(10000 / BATCH_SIZE)

# compute embeddings for real images
real_image_embeddings = compute_embeddings(trainloader, count)

# compute embeddings for generated images
generated_image_embeddings = compute_embeddings(genloader, count)

print(real_image_embeddings.shape, generated_image_embeddings.shape)
```
- Here, trainloader and genloader are tf.data.Dataset objects that yield batches of real and generated images, respectively. Check out the Colab notebook for implementation details.
- With the real and generated image embeddings, we can compute the FID score:
```python
import numpy as np
from scipy import linalg


def calculate_fid(real_embeddings, generated_embeddings):
    # calculate mean and covariance statistics
    mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(generated_embeddings, rowvar=False)
    # squared L2 distance between the mean vectors
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    # matrix square root of the product of the covariance matrices
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    # numerical error in sqrtm can produce small imaginary components; discard them
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid


fid = calculate_fid(real_image_embeddings, generated_image_embeddings)
```
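The notebook's loaders are not reproduced here. As a rough sketch of one way to prepare their inputs, assuming FashionMNIST-style grayscale images in [0, 255] of shape (28, 28, 1) and a generator with a tanh output in [-1, 1], each batch can be resized to Inception V3's native 299x299 resolution, expanded to three channels, and rescaled to [-1, 1] (the same scaling that tf.keras.applications.inception_v3.preprocess_input applies). The helper names `preprocess_for_inception` and `tanh_to_pixels` are hypothetical:

```python
import numpy as np
import tensorflow as tf


def preprocess_for_inception(images):
    """Map grayscale images in [0, 255], shape (N, H, W, 1), to Inception V3 inputs."""
    images = tf.cast(images, tf.float32)
    # Upsample to Inception V3's native 299x299 resolution
    images = tf.image.resize(images, (299, 299))
    # Repeat the single channel to get the 3 channels Inception expects
    images = tf.image.grayscale_to_rgb(images)
    # Scale [0, 255] -> [-1, 1], as inception_v3.preprocess_input does
    return images / 127.5 - 1.0


def tanh_to_pixels(images):
    """Map generator outputs in [-1, 1] (e.g. from a tanh layer) back to [0, 255]."""
    return (tf.cast(images, tf.float32) + 1.0) * 127.5


# Usage with a hypothetical batch of random FashionMNIST-like images
fake_images = np.random.default_rng(0).integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
dataset = (
    tf.data.Dataset.from_tensor_slices(fake_images[..., np.newaxis])
    .batch(4)
    .map(preprocess_for_inception)
)
for batch in dataset.take(1):
    print(batch.shape)  # (4, 299, 299, 3)
```

For generated images, apply `tanh_to_pixels` before `preprocess_for_inception` so that real and fake images go through exactly the same preprocessing; otherwise the FID partly measures a preprocessing mismatch.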
Now that you know how to compute the FID score, let's compute it for several GAN checkpoints. I used W&B Artifacts to store a checkpoint every 5 epochs and, using the functions described above, computed the FID score for each checkpoint. The result is shown in the media panel below.
Observations
- The FID score decreases as the model checkpoints improve, so you can pick the checkpoint with the lowest FID score for inference.
- The FID scores are relatively high because the Inception model was trained on ImageNet, which consists of natural color images, while our GAN was trained on the FashionMNIST dataset of 28x28 grayscale clothing images.
Shortcomings of FID
- It relies on a pre-trained Inception model, whose features may not capture everything that matters for your dataset. This can result in a high FID score, as in the case above.
- It needs a large sample size: the minimum recommended sample size is 10,000. For high-resolution images (say, 512x512 pixels), generating and embedding that many samples can be computationally expensive and slow.
- It summarizes each set of embeddings with only two statistics (the mean and the covariance), so differences between the distributions beyond these statistics are not captured.
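One concrete reason the sample size matters is the covariance estimate: the sample covariance of N vectors has rank at most N − 1, so with fewer samples than embedding dimensions it is singular, and the matrix square root in the FID formula can then produce complex values or NaNs. The sketch below illustrates this with random Gaussian vectors standing in for Inception embeddings; the 256-dimensional size (instead of 2048) is an assumption made only to keep it fast, and `covariance_rank` is a hypothetical helper:

```python
import numpy as np


def covariance_rank(n_samples, dim, seed=0):
    """Rank of the sample covariance of `n_samples` random `dim`-dimensional vectors."""
    rng = np.random.default_rng(seed)
    # Random Gaussian vectors standing in for Inception embeddings (hypothetical data)
    embeddings = rng.standard_normal((n_samples, dim))
    sigma = np.cov(embeddings, rowvar=False)  # (dim, dim) covariance matrix
    return int(np.linalg.matrix_rank(sigma))


dim = 256  # scaled-down stand-in for Inception's 2048 dimensions
for n_samples in (100, 1000):
    print(n_samples, covariance_rank(n_samples, dim))
# 100 samples give rank 99 (singular); 1000 samples give full rank 256
```

With real Inception embeddings the same reasoning applies to the 2048x2048 covariance matrices, which is one reason FID is computed on thousands of samples.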