Skip to main content

Comment évaluer les GAN en utilisant la Fréchet Inception Distance (FID) ?

Dans ce rapport, nous discuterons rapidement des pièges de l’évaluation des GAN et de la manière d’implémenter un pipeline d’évaluation FID.
Created on February 10|Last edited on February 2
Ceci est une traduction d'un article en anglais qui peut être trouvé ici.



🌟 Introduction

Si les réseaux antagonistes génératifs (Generative Adversarial Networks, GAN) vous sont familiers et que vous en avez déjà entraîné, vous vous êtes sûrement demandé quel checkpoint de modèle utiliser pour inférence. Dans la classification d’image, vous utilisez le checkpoint de modèle qui vous donne la meilleure précision de validation.

Cependant, ce n’est pas le cas pour les GAN. Entraîner un GAN peut s’avérer compliqué et instable, et présenter des problèmes comme l’effondrement de mode, etc. Si vous voulez en apprendre davantage sur les modèles génératifs de manière générale, consultez le rapport Vers une Modélisation Générative Profonde avec W&B.

Essayez d’entraîner votre propre GAN simple →\rightarrow

Les mesures d’entraînement sont affichées dans les graphiques ci-dessous. La question est de savoir si, en regardant les mesures d’entraînement, vous êtes en mesure de déterminer le checkpoint de modèle que vous utiliserez pour générer de nouvelles images.




Run set
7


Une manière naïve d’évaluer un GAN consiste à « baby-sitter » le processus d’entraînement de modèle, c’est-à-dire à visionner manuellement les images générées en utilisant des checkpoints de modèle. Vous pouvez utiliser une fonction de rappel (callback) pour charger un ensemble d’images générées toutes les n epochs.

Si vous avez consulté le notebook Colab dont j’ai mis le lien plus haut, j’ai implémenté une fonction de rappel personnalisée Keras pour enregistrer les images générées.

class GeneratedImageLogger(tf.keras.callbacks.Callback):
    def __init__(self, noise_dim, batch_size=32):
      super(GeneratedImageLogger, self).__init__()
      self.noise = tf.random.normal([batch_size, noise_dim])

    def on_epoch_end(self, logs, epoch):
      generated_image = generator(self.noise, training=False)

      wandb.log({"gen_images": [wandb.Image(image)
                            for image in generated_image]})

Les résultats de l’utilisation de la fonction de rappel sont affichés ci-dessous. Cliquez sur l’icône ⚙️dans le panneau d’images ci-dessous et changez le nombre d’étapes pour visualiser les images générées àchaque epoch.




Run set
1


Évaluation GAN

image.png

Figure 1: (Source)

  • Dans la classification d’image supervisée, l’évaluation est simple. Il faut comparer l’output prédit avec l’output réel.

  • Cependant, avec un GAN, vous passez du bruit aléatoire pour obtenir cette image fausse (générée). Nous voulons que cette image générée ait l’air aussi réaliste que possible. Donc, comment faire pour quantifier le réalisme d’une image générée ? Ou comment pouvez-vous évaluer votre GAN, avec précision ?

Nous commencerons par paramétrer deux propriétés simples pour notre mesure d’évaluation :

  • La fidélité (fidelity) : nous voulons que notre GAN génère des images de haute qualité.
  • La diversité (diversity) : notre GAN devrait générer des images qui sont inhérentes au jeu de données (dataset) d’entraînement.

Ainsi, notre mesure d’évaluation devrait être axée sur ces deux propriétés. Cependant, comparer des images sur la fidélité et la diversité peut être complexe, parce qu’une question se pose : que doit-on comparer, au juste ? Les deux approches suivantes sont largement utilisées en vision par ordinateur (computer vision) pour comparer des images :

  • La distance entre pixels : Ceci est une mesure de distance naïve où nous soustrayons les valeurs de pixel des deux images. Cependant, ce n’est pas une mesure fiable.

  • La distance entre caractéristiques : Nous utilisons un modèle de classification d’image pré-entraîné et l’activation d’une couche intermédiaire. Ce vecteur est la représentation à haut niveau de l’image. Calculer une mesure de distance avec une représentation pareille donne une mesure stable et fiable.

Maintenant que nous avons développé quelques bases, voyons rapidement comment utiliser la Fréchet Inception Distance (FID) pour évaluer les GAN.



❄️ Frechet Inception Distance(FID)

C’est l’une des mesures les plus populaires pour mesurer la distance entre les points caractéristiques des images réelles et des images générées. La distance de Fréchet est une mesure de similarité entre plusieurs courbes qui prennent en compte l’emplacement et l’ordre des points le long de ces courbes. Elle peut également être utilisée pour mesurer la distance entre deux distributions.

Évaluer votre GAN en utilisant la FID→\rightarrow

Mathématiquement la Distance de Fréchet est utilisée pour calculer la distance entre deux distributions normales "multivariées". Pour une distribution normale "univariée", la Distance de Fréchet est donnée comme,

d(X,Y)=(μX−μY)2+(σX−σY)2d(X, Y) = (μ_X - μ_Y)^2 + (σ_X - σ_Y)^2

où μ et σ représentent respectivement la moyenne et l’écart-type des distributions normales, et Xet Y représentent deux distributions normales.

Dans le contexte de la vision par ordinateur, en particulier de l’évaluation GAN, nous utilisons la distance d’élément comme décrite ci-dessus. Nous utilisons le modèle pré-entraîné Inception V3 sur le jeu de données ImageNet. L’utilisation des activations du modèle Inception V3 pour résumer chaque image donne à ce score son nom « Fréchet Inception Distance”.

Cette activation est tirée de l’avant-dernière couche de regroupement (pooling layer) – ou regroupement moyen global (Global Average Pooling) si vous utilisez TensorFlow. Nous supposons que le vecteur d’output de forme (2048, ) sera approximé par distribution normale "multivariée".

La Fréchet Inception Distance pour une distribution normale "multivariée" est donnée par,

FID=∣∣μX−μY∣∣2−Tr(∑X+∑Y−2∑X∑Y)FID = ||μ_X - μ_Y||^2 - Tr(\sum_X + \sum_Y - 2\sqrt{\sum_X\sum_Y})

XX et YY sont respectivement le vrai et le faux plongement (embedding) (activation depuis le modèle Inception), qui sont supposés être deux distributions normales multivariées. μX​ et μY sont les magnitudes des vecteurs XX et YY. TrTr est la trace de la matrice (algèbre linéaire) et ∑X​ et ∑Y sont les matrices de covariance de ces vecteurs.



🐦 Implémenter la FID en utilisant TensorFlow

Dans cette section, nous nous intéresserons à l’implémentation du pipeline d’évaluation de GAN en utilisant le score FID. Nous nous pencherons également sur les composants clefs de ce dernier.

Pour évaluer votre GAN en utilisant la FID →\rightarrow

  1. Pour utiliser le modèle Inception V3 pré-entraîné :

    inception_model = tf.keras.applications.InceptionV3(include_top=False, 
                                  weights="imagenet", 
                                  pooling='avg')
    
  2. Calculer les plongements pour les vraies images et les images générées. Notez bien que les auteurs de la publication GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium ont recommandé d’utiliser une taille d’échantillon minimum de 10 000 pour calculer la FID, sans quoi la vraie FID du générateur est sous-estimée.

    def compute_embeddings(dataloader, count):
        image_embeddings = []
    
        for _ in tqdm(range(count)):
            images = next(iter(dataloader))
            embeddings = inception_model.predict(images)
    
            image_embeddings.extend(embeddings)
    
        return np.array(image_embeddings)
    
    count = math.ceil(10000/BATCH_SIZE)
    
    # compute embeddings for real images
    real_image_embeddings = compute_embeddings(trainloader, count)
    
    # compute embeddings for generated images
    generated_image_embeddings = compute_embeddings(genloader, count)
    
    real_image_embeddings.shape, generated_image_embeddings.shape
    

    Ici, trainloader et genloader sont des datasets tf.data. Consultez le colab notebook pour obtenir les détails d’implémentation.

  3. Avec les plongements des vraies images et de celles générées, nous allons calculer le score FID.

     def calculate_fid(real_embeddings, generated_embeddings):
         # calculate mean and covariance statistics
         mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
         mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(generated_embeddings,  rowvar=False)
         # calculate sum squared difference between means
        ssdiff = np.sum((mu1 - mu2)**2.0)
        # calculate sqrt of product between cov
        covmean = linalg.sqrtm(sigma1.dot(sigma2))
        # check and correct imaginary numbers from sqrt
        if np.iscomplexobj(covmean):
           covmean = covmean.real
         # calculate score
         fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
         return fid
    
     fid = calculate_fid(real_image_embeddings, generated_image_embeddings)
    

Maintenant que vous savez comment calculer le score FID, nous allons calculer le score FID pour chaque checkpoint GAN. J’ai utilisé les artéfacts de W&B pour stocker les checkpoints toutes les 5 epochs. En utilisant les fonctions décrites ci-dessus, j’ai calculé les scores FID pour chaque checkpoint. Les résultats sont affichés dans le graphique ci-dessous.




Run set
1


Observations

  • Le score FID diminue avec le checkpoint du meilleur modèle. On peut choisir un checkpoint de modèle qui génère un score FID bas pour inférence.

  • Le score FID est relativement haut parce que le modèle Inception est entraîné sur ImageNet, qui est constitué d’images naturelles tandis que notre GAN est entraîné sur le jeu de données Fashion MNIST.

🐸 Les limites de la FID

  • Elle utilise un modèle Inception pré-entrainé, qui peut ne pas capturer tous les éléments. Cela peut résulter en un score FID élevé, comme nous l’avons vu précédemment.

  • Elle requiert une grande taille d’échantillon. La taille minimum d’échantillon recommandée est 10 000. Pour une image à haute résolution (disons 512x512 pixels), cela peut demander des ressources informatiques considérables et être lent à exécuter.

  • Des statistiques limitées (moyenne et covariance) sont utilisées pour calculer le score FID.


Iterate on AI agents and models faster. Try Weights & Biases today.