프레쳇 인셉션 거리 (Frechet Inception distance, FID)를 사용해 GANs 평가하는 법은 무엇인가요?
🌟 서론
생성적 대립쌍 네트워크(Generative Adversarial Networks(GAN))에 익숙하고 이를 훈련해 본 적 있다면, 추론에 사용할 모델 체크포인트(model checkpoint)에 대해 궁금하셨을 겁니다. 이미지 분류에서 모델 체크포인트를 사용하며 이는 최고의 검증 정확도를 제공합니다.
However, that is not the case for GANs. Training a GAN is hard and unstable and has issues like mode collapse, etc. If you want to learn more about generative models in general, check out Towards Deep Generative Modeling with W&B. 그러나, GAN의 경우는 그렇지 않습니다. GAN를 훈련시키는 것은 까다롭고 불안정하며 모드 붕괴(mode collapse)와 같은 문제점을 지니고 있습니다. 생성 모델(genitive models)에 대해 더 자세히 알아보시려면 Towards Deep Generative Modeling with W&B를 참조하시기 바랍니다.
간단한 GAN 직접 훈련해보기 →\rightarrow
훈련 매트릭은 아래의 미디어 패널에 표시돼 있습니다. 문제는, 훈련 매트릭을 확인하여, 새로운 이미지 생성에 사용할 모델 체크포인트를 여러분께서 결정할 수 있는지입니다.
GAN를 평가하는 가장 간단한 방법은 모델 훈련 프로세스를 “베이비 시팅” 하는 것입니다. 즉, 모델 체크포인트를 사용해서 생성된 이미지를 직접 살펴보는 것입니다. 콜백을 사용해서 생성된 이미지의 배치를 매 n
에포크에 로그할 수 있습니다.
상단에 링크된 colab notebook 을 확인하셨다면, 생성된 이미지를 로그하기 위해 사용자 지정된 Keras 콜백을 구현한 부분을 확인하실 수 있습니다.
class GeneratedImageLogger(tf.keras.callbacks.Callback):
def __init__(self, noise_dim, batch_size=32):
super(GeneratedImageLogger, self).__init__()
self.noise = tf.random.normal([batch_size, noise_dim])
def on_epoch_end(self, logs, epoch):
generated_image = generator(self.noise, training=False)
wandb.log({"gen_images": [wandb.Image(image)
for image in generated_image]})
콜백을 사용한 결과는 아래에 나타나 있습니다. 하단에 표시된 미디어 패널에서 ⚙️ icon아이콘을 클릭하고 모든 에포크에 생성된 이미지를 시각화하는 단계 수를 변경합니다.
GAN 평가
그림 1: (출처)
-
강화 학습된 이미지 분류 작업 (Supervised Image Classification)에서 평가는 간단합니다. 예측 출력을 실제 출력과 비교해야 합니다.
-
그러나 GAN을 사용하면 무작위의 노이즈를 전달하여 가짜(생성된) 이미지를 얻을 수 있습니다. 우리는 이 생성된 이미지가 최대한 실제 이미지처럼 보이기를 원합니다. 그렇다면, 어떻게 이 생성된 이미지의 사실성을 정확하게 계량화할 수 있을까요? 혹은 어떻게 GAN을 정확하게 평가할 수 있습니까?
먼저 평가 매트릭에 대한 두 가지 간단한 속성 설정부터 시작해보겠습니다:
-
Fidelity(충실도): 저희는 저희 GAN이 고품질의 이미지를 생성하기를 바랍니다.
-
Diversity(다양성): 저희 GAN은 훈련 데이터세트에 내재된 이미지를 생성해야 합니다.
따라서 평가 지표는 이 두 가지 속성 모두에 대해 평가해야 합니다. 그러나 충실도 및 다양성에 대하여 이미지를 비교하는 것은 까다로울 수 있습니다. 대체 정확히 무엇을 비교해야 하나요? 컴퓨터 비전에서 널리 사용되는 두 가지 이미지 비교 방식은 다음과 같습니다.
-
Pixel Distance(픽셀 거리): 두 이미지의 픽셀 값을 빼는 간단한 거리 측정 방식입니다. 그러나 신뢰할 만한 매트릭은 아닙니다.
-
Feature Distance(특징 거리): 저희는 사전 훈련된 이미지 분류 모델을 사용하며 중간 레이어의 활성화(activation)를 사용합니다. 이 벡터는 이미지의 상위 수준 표현(high-level representation)입니다. 이러한 표현으로 거리 매트릭을 계산하면 안정적이고 신뢰할 수 있는 매트릭을 얻을 수 있습니다.
이제 몇 가지 기본 사안에 대해서 다루었으므로, 프레쳇 인셉션 거리(FID)를 사용하여 GAN을 평가하는 법에 대해서 빠르게 살펴보겠습니다.
❄️ 프레쳇 인셉션 거리(FID)
이 메트릭은 실제 이미지와 생성된 이미지 간의 특징 거리 측정에 가장 널리 사용되는 매트릭 중 하나입니다. 프레쳇 거리Frechet Distance는 곡선을 따라는 점들(points)의 위치와 순서를 고려한 곡선 간의 유사성을 측정하는 방법입니다. 이는 두 분포 사이의 거리를 측정하는 데에도 사용됩니다.
FID를 사용하여 GAN 평가하기 →\rightarrow
수학적으로, 프레쳇 거리는 두 “다변량” 정규분포(multivariate normal distributin) 사이의 거리를 계산하는데 사용됩니다. “일변량” 정규분포(univariate normal distribution)의 경우, 프레쳇 거리는 다음과 같이 계산됩니다:
d(X,Y)=(μX−μY)2+(σX−σY)2d(X, Y) = (μ_X - μ_Y)^2 + (σ_X - σ_Y)^2
여기서 μμ와 σσ는 정규분포의 평균 및 표준 편차이며, XX,와 YY 는 두 개의 정규분포입니다.
컴퓨터 비전, 특히 GAN 평가의 맥락에서, 저희는 위에서 설명한 바와 같이 특징 거리를 사용합니다. Imagenet 데이터세트에서 사전 훈련된 Inception V3 모델을 사용하겠습니다. 각 이미지를 요약하기 위한 Inception V3 모델에서 활성화(activations)를 사용하면 스코어(score)에 “Frechet Inception Distance(프레쳇 인셉션 거리)”라는 이름이 부여됩니다.
페널티메이트 풀링 레이어(penultimate pooling layer, 끝에서 두 번째 풀링 레이어) (TensorFlow를 사용하시는 경우 글로벌 평균 풀링(Global Average Pooling))에서 이 활성화를 가져옵니다. 저희는 shape (2048, )
의 출력 벡터가 “다변량” 정규분포에 가깝다고 가정합니다.
“다변량” 정규분포에 대한 프레쳇 인셉션 거리는 다음에 의해 주어집니다:
FID=∣∣μX−μY∣∣2−Tr(∑X+∑Y−2∑X∑Y)FID = ||μ_X - μ_Y||^2 - Tr(\sum_X + \sum_Y - 2\sqrt{\sum_X\sum_Y})
여기서 XX 와 YY는 두 개의 다변량 정규분포로 가정된 실제와 가짜 임베딩(Inception 모델에서 활성화)입니다. μXμ_X 와 μYμ_Y는 벡터 XX 와 YY의 크기(magnitude)입니다. TrTr은 행렬의 대각합(trace)(linearalgebra)이며 $\sum X 와 ∑Y\sum_Y 는 벡터의 공분산 행렬(covariance matrix)입니다.
🐦 TensorFlow 사용을 통한 FID 구현
이 섹션에서, FID 스코어(score)를 사용하여 GAN 평가 파이프라인의 구현에 대해 살펴보겠습니다. 동일 핵심 요소를 살펴보겠습니다.
FID를 사용하여 GAN 평가하기 →\rightarrow
-
사전 훈련된 Inception V3 모델을 사용하려면 다음을 수행합니다:
inception_model = tf.keras.applications.InceptionV3(include_top=False, weights="imagenet", pooling='avg')
-
실제 이미지와 생성된 이미지에 대한 임베딩을 계산합니다. 참조: GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium의 저자는 최소 샘플 사이즈 10,000을 사용하여 FID를 계산할 것을 권장하고 있습니다. 그렇지 않은 경우 생성기(generator)의 실제 FID(true FID)가 과소평가됩니다.
def compute_embeddings(dataloader, count): image_embeddings = [] for _ in tqdm(range(count)): images = next(iter(dataloader)) embeddings = inception_model.predict(images) image_embeddings.extend(embeddings) return np.array(image_embeddings) count = math.ceil(10000/BATCH_SIZE) # compute embeddings for real images real_image_embeddings = compute_embeddings(trainloader, count) # compute embeddings for generated images generated_image_embeddings = compute_embeddings(genloader, count) real_image_embeddings.shape, generated_image_embeddings.shape
여기
trainloader
,genloader
및tf.data
데이터세트가 있습니다. 구현 세부 사항의 경우 colab notebook을 참조하시기 바랍니다. -
실제 및 생성된 이미지 임베딩을 사용하여, FID 스코어를 계산하겠습니다.
def calculate_fid(real_embeddings, generated_embeddings): # calculate mean and covariance statistics mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False) mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(generated_embeddings, rowvar=False) # calculate sum squared difference between means ssdiff = np.sum((mu1 - mu2)**2.0) # calculate sqrt of product between cov covmean = linalg.sqrtm(sigma1.dot(sigma2)) # check and correct imaginary numbers from sqrt if np.iscomplexobj(covmean): covmean = covmean.real # calculate score fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) return fid fid = calculate_fid(real_image_embeddings, generated_image_embeddings)
이제 FID 스코어를 계산하는 방법에 대해서 알게 되었습니다. 이제 GAN 체크포인트의 FID 스코어 계산을 해보겠습니다. 저의 경우 W&B 아티팩트를 사용하여 5개의 에포크마다 체크포인트를 저장해왔습니다. 상단에 설명된 함수를 사용해서 각 체크포인트에 대한 FID 스코어를 계산했습니다. 결과는 아래의 미디어 패널에 표시됩니다.
관측
-
더 나은 모델 체크포인트를 사용하면 FID 스코어는 낮아집니다. 추론(inference)를 위해 낮은 FID 스코어를 생성하는 모델 체크포인트를 선택할 수 있습니다.
-
GAN은 FashionMNIST 데이터세트에서 훈련되고 자연 이미지를 구성하는 Inception 모델이 Imagenet에서 훈련되었으므로 FID 스코어는 상대적으로 높습니다.
🐸 FID의 단점
-
사전 훈련된 Inception 모델을 사용하며, 이는 모든 특징을 담아내지 못할 수 있습니다. 이 경우 위의 경우와 마찬가지로 FID 스코어가 높을 수 있습니다.
-
큰 샘플 사이즈가 필요합니다. 권장되는 최소 샘플 사이즈는 10,000입니다. 고해상도 이미지(예: 512x512 픽셀)의 경우, 컴퓨터 리소스를 많이 사용하고 실행 속도가 느릴 수 있습니다.
-
FID 스코어 계산에 제한된 통계량(평균 및 공분산)이 사용됩니다.