如何使用Frechet Inception Distance(FID)评估GANs模型?
🌟 导言
如果您熟悉Generative Adversarial Networks(GAN)并已对其进行了训练,您一定想知道要使用哪个模型检查点(model checkpoint)进行推理。 在图像分类中,您可以使用模型检查点来得到最佳的验证准确性。
但是,GAN并非如此。 训练GAN既困难又不稳定,并且存在诸如模型崩溃等问题。如果您想全面了解generative models,请查看使用W&B探索Deep Generative Modeling。
尝试训练一个简单的GAN →\rightarrow
训练指标显示在下面的媒体面板中。 您是否可以确定将用于生成新图像的模型检查点?*训练指标显示在下面的媒体面板中。
评估GAN的一种简单方法是“盯着”模型训练过程,即手动查看使用模型检查点生成的图像。您可以使用回调函数在每n
个时期记录一批生成的图像。
如果您已经查看了上面链接的colab笔记本,那么可以看到我已经自定义Keras回调以记录生成的图像。
class GeneratedImageLogger(tf.keras.callbacks.Callback):
def __init__(self, noise_dim, batch_size=32):
super(GeneratedImageLogger, self).__init__()
self.noise = tf.random.normal([batch_size, noise_dim])
def on_epoch_end(self, logs, epoch):
generated_image = generator(self.noise, training=False)
wandb.log({"gen_images": [wandb.Image(image)
for image in generated_image]})
使用回调的结果如下所示。单击下面显示的媒体面板中的⚙️ icon图标,然后更改步骤数以可视化每个时期生成的图像。
GAN评估
图1: (来源)
-
在监督式图像分类中,评估很简单。我们需要将预测输出与实际输出进行比较。
-
但是,使用GAN,您会传入一些随机噪声以获取此伪造(生成的)图像。我们希望生成的图像看起来尽可能真实我们希望生成的图像看起来尽可能真实。那么,如何精确地量化生成的图像的真实性呢? 那么,如何精确地量化生成的图像的真实性呢?或者您如何准确评估GAN?
我们将从为评估指标设置两个简单属性开始:
- 保真度(Fidelity): 我们希望我们的GAN能够生成高质量的图像。
- 多样性(Diversity): 我们的GAN应该生成训练数据集中固有的图像。
因此,我们的评估指标应同时针对这两个属性进行评估。但是,同时比较保真度和多样性可能具有挑战性,因为您到底应该比较什么?计算机视觉中广泛使用的图像比较的两种方法是:
-
像素距离:: 这是一个简单的距离指标——两个图像的像素值相减。但是,这不是一个可靠的指标。
-
特征距离: 我们使用预先训练的图像分类模型,并使用中间层的激活。此向量是图像的高级表示。用这种表示来计算距离指标会更加稳定且可靠。
现在我们已经涵盖了一些基础,我们将快速了解如何使用Frechet Inception Distance(FID)来评估GAN。
❄️ Frechet 起始距离(Frechet Inception Distance / FID)
这是用于测量真实图像与生成图像之间的特征距离的最受欢迎的度量标准之一。Frechet Distance是衡量曲线之间相似度的一种方法,其中考虑了沿曲线的点的位置和顺序。它也可用于测量两个分布之间的距离。
使用FID评估您的GAN →\rightarrow
在数学上,Frechet Distance用于计算两个“多元”正态分布之间的距离。对于“单变量”正态分布,Frechet距离为
d(X,Y)=(μX−μY)2+(σX−σY)2d(X, Y) = (μ_X - μ_Y)^2 + (σ_X - σ_Y)^2
其中μμ和σσ是正态分布的均值和标准偏差, XX,和YY是两个正态分布。
在计算机视觉尤其是GAN评估中,我们使用如上所述的特征距离。 我们使用在Imagenet数据集上预先训练的Inception V3模型。来自Inception V3模型的激活被用来汇总每个图像,这是为什么该得分的名称为“ Frechet Inception Distance”。
此激活来自倒数第二个pooling layer(如果使用TensorFlow,则为Global Average Pooling)。 我们假设形状为(2048,)
的输出向量将通过“多元”正态分布进行近似。
“多元”正态分布的Frechet Inception Distance由下式给出:
FID=∣∣μX−μY∣∣2−Tr(∑X+∑Y−2∑X∑Y)FID = ||μ_X - μ_Y||^2 - Tr(\sum_X + \sum_Y - 2\sqrt{\sum_X\sum_Y})
其中XX和YY 分别是真实的和‘假’的嵌入(来自Inception模型的激活),被假定为两个多元正态分布。 μXμ_X和μYμ_Y是向量XX 和YY的大小。TrTr 是矩阵的迹线(linearalgebra),而∑ $\sum X$ 和∑∑Y\sum_Y是向量的协方差矩阵 。
🐦 使用TensorFlow实施FID
在本节中,我们将研究如何使用FID评分的GAN评估。我们将研究相同的关键组件。
使用FID评估您的GAN→\rightarrow
-
要使用预训练的Inception V3模型,请执行以下操作:
inception_model = tf.keras.applications.InceptionV3(include_top=False, weights="imagenet", pooling='avg')
-
计算真实图像和生成图像的嵌入。请注意,GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium 的作者建议使用最小样本大小10,000来计算FID,否则会低估生成器的真实FID。
def compute_embeddings(dataloader, count): image_embeddings = [] for _ in tqdm(range(count)): images = next(iter(dataloader)) embeddings = inception_model.predict(images) image_embeddings.extend(embeddings) return np.array(image_embeddings) count = math.ceil(10000/BATCH_SIZE) # compute embeddings for real images real_image_embeddings = compute_embeddings(trainloader, count) # compute embeddings for generated images generated_image_embeddings = compute_embeddings(genloader, count) real_image_embeddings.shape, generated_image_embeddings.shape
在这里,
trainloader
和genloader
是tf.data
数据集。请查看colab笔记本以获取操作细节。 -
使用真实和生成的图像嵌入,我们将计算FID分数。
def calculate_fid(real_embeddings, generated_embeddings): # calculate mean and covariance statistics mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False) mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(generated_embeddings, rowvar=False) # calculate sum squared difference between means ssdiff = np.sum((mu1 - mu2)**2.0) # calculate sqrt of product between cov covmean = linalg.sqrtm(sigma1.dot(sigma2)) # check and correct imaginary numbers from sqrt if np.iscomplexobj(covmean): covmean = covmean.real # calculate score fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) return fid fid = calculate_fid(real_image_embeddings, generated_image_embeddings)
现在,您知道了如何计算FID分数。让我们计算GAN检查点的FID分数。我已经使用了W&B Artifacts在每5个epoch保存检查点。使用上述功能,我计算了每个检查点的FID分数。结果显示在下面的媒体面板中。
结果观察
-
使用更好的模型检查点,FID分数会降低。我们可以选择一个具有较低FID分数的模型检查点进行推理。
-
FID分数相对较高,因为Inception模型是在Imagenet(自然图像)上训练的,而我们的GAN是在FashionMNIST数据集上训练的。
🐸 FID的缺点
-
它使用了预训练的Inception模型,该模型可能无法捕获所有特征。与上述情况一样,这可能会导致FID得分较高。
-
它需要大样本量。建议的最小样本量为10,000。对于高分辨率图像(例如512x512像素),这可能在计算上昂贵且运行缓慢。
-
用于计算FID得分的统计量(均值和协方差)有限。