Skip to main content

조건부 디퓨전(Conditional Diffusion) 모델을 완전 처음 트레이닝하는 방법

이 모델들 중 하나를 트레이닝 시킴으로 무엇을 러닝할 수 있을까요?
Created on December 9|Last edited on February 14
이는 여기에서 볼 수 있는 영어 기사를 번역한 것이다.


DALL-E에서, 스테이블 디퓨전(Stable Diffusion)에 이르기까지, 이미지 생성은 현재 딥 러닝에서 아마 가장 흥미로운 무언가입니다. 하지만 트위터에 게시할 웃긴 이미지를 만드는 것 외에도 어떤 다른 활용 사례를 생각해 볼 수 있을까요?
최근에 저는 유튜브에서 PyTorch에서 조건부 디퓨전(Conditional Diffusion) 모델 프로그래밍에 관한 훌륭한 영상 하나를 발견했습니다. 이 코드를 더 자세히 알아보고 저희만의 모델 또한 트레이닝을 할 것이니 이 글을 읽어보시기 전에 이 영상을 볼 것을 추천 드립니다!


이 영상에서 배울 수 있는 것은 바로 지도(supervised) 데이터 세트에서 데이터를 생성하도록 모델을 트레이닝 시킬 수 있다는 것입니다.
레이블이 지정된 데이터 세트가 있으면 합성 데이터를 생성할 수 있습니다! 이는 복잡하고 값비싼 레이블링(to-label) 데이터에 매우 유용할 수 있습니다.
💡

초기 체크업: CIFAR-10

원본 코드는 CIFAR-10을 사용하여 조건부 디퓨전 모델을 트레이닝합니다. 그러니 먼저 한번 해봅시다! 시작하기 전에 코드 기반에 대한 몇 가지 개선 사항이 있습니다:
  • 검증(validation) 메트릭을 추가했습니다(테스트 데이터를 기반으로 계산).
  • 혼합 정밀 트레이닝과 멀티스레드(multithreaded) 데이터 로드를 활성화했습니다.
  • OneCycleScheduler
  • Weights 와 Biases (W&B) 로깅
이 실험을 위한 코드는https://github.com/tcapelle/Diffusion-Models-pytorch에서 확인할 수 있습니다.
💡
이 원격 측정을 통해 모델의 트레이닝 방식을 추적할 수 있었지만 다음과 같은 모델에는 충분하지 않습니다:


보시다시피 손실('train_mse')이 그다지 원활하지 않기 때문에 모델들이 아무것도 배우지 못하고 있다고 생각하실 수 있습니다. 그러나 샘플링된 이미지를 플롯하면(10개의 에포크 마다 디퓨전 추론을 실행하고 이미지를 W&B에 기록) 모델이 어떻게 계속 개선되는지 알 수 있습니다. 아래로 슬라이더를 이동시키면 시간이 지남에 따라 모델이 어떻게 개선되는지 확인할 수 있습니다. 저희 일반 모델과 EMA 복사 모델을 모두 샘플링합니다.
슬라이더 번호는 모델에서 볼 수 있는 뱃치이지만 10개의 에포크 마다 기록되었습니다.


참고: 영상에서 저자는 EMA 모델이 더 나은 출력을 낸다고 주장하지만, 저는 확신할 수 없습니다.
💡

기술적 세부사항

이 트레이닝을 가능하게 하는 수 많은 세부 사항들이 있습니다. 우선, 저희는 최신 PyTorch를 사용하여 GCP 머신에 대한 모든 트레이닝을 실시했습니다. A100(40GB)으로 넘어가기 전에 CIFAR 초기 트레이닝에 V100 16GB를 사용했습니다. 다음은 몇 가지 추가 사항입니다:
  • 이미지의 노이즈를 제거하는 UNet 모델은 사이즈가 큽니다. 많은 셀프 어텐션 레이어를 가지고 있고 계산의 양이 막대합니다. 따라서, 이미지의 해상도가 작더라도 모든 픽셀 간에 셀프 어텐션이 계산되기 때문에 이미지 크기에 있어서는 쿼드래틱(quadratic : 2차적)입니다.
  • 보다 선명한 결과를 얻기 위해 32x32 대신 64x64 이미지가 있는 CIFAR-10의 업그레이드 버전을 사용했습니다. 이 데이터 세트는 Kaggle에서 찾을 수 있습니다.
  • 저는 V100에 4가 아닌 10과 같은 뱃치 크기만 혼합 정밀도 없이 장착할 수 있었습니다.
  • 또한 네트워크의 더 깊은 병목 현상(512 x 512) 컨볼루션(convolutional) 레이어 중 하나를 억제하여 조금 더 빠르게 만들었습니다.

코드

이 문서의 전체 코드는 여기에서 찾을 수 있습니다. 이 문서에서 코드베이스의 주요 부분에 대해서만 논의할 것이며, 질문이 있는 경우 저에게 연락해주세요 (또는 댓글을 달아주세요!).

모델

기본 무조건적 디퓨전(non-conditional diffusion) 모델은 셀프 어텐션 레이어가 있는 UNet으로 구성됩니다. 저희는 다운샘플링과 업샘플링 경로가 있는 고전적인 U 스트럭쳐를 가지고 있습니다. 기존 UNet과의 주요 차이점은 위쪽 및 아래쪽 블록이 전진 패스(forward pass)에 대한 추가적인 timestep 인수를 지원한다는 것입니다. 이 작업은 timestep을 컨볼루션에 선형으로 포함시킴으로써 수행됩니다. 자세한 내용은modules.py file을 참고해주세요.
class UNet(nn.Module):
def __init__(self, c_in=3, c_out=3, time_dim=256):
super().__init__()
self.time_dim = time_dim
self.inc = DoubleConv(c_in, 64)
self.down1 = Down(64, 128)
self.sa1 = SelfAttention(128)
self.down2 = Down(128, 256)
self.sa2 = SelfAttention(256)
self.down3 = Down(256, 256)
self.sa3 = SelfAttention(256)

self.bot1 = DoubleConv(256, 256)
self.bot2 = DoubleConv(256, 256)

self.up1 = Up(512, 128)
self.sa4 = SelfAttention(128)
self.up2 = Up(256, 64)
self.sa5 = SelfAttention(64)
self.up3 = Up(128, 64)
self.sa6 = SelfAttention(64)
self.outc = nn.Conv2d(64, c_out, kernel_size=1)
def unet_forwad(self, x, t):
"Classic UNet structure with down and up branches, self attention in between convs"
x1 = self.inc(x)
x2 = self.down1(x1, t)
x2 = self.sa1(x2)
x3 = self.down2(x2, t)
x3 = self.sa2(x3)
x4 = self.down3(x3, t)
x4 = self.sa3(x4)

x4 = self.bot1(x4)
x4 = self.bot2(x4)

x = self.up1(x4, x3, t)
x = self.sa4(x)
x = self.up2(x, x2, t)
x = self.sa5(x)
x = self.up3(x, x1, t)
x = self.sa6(x)
output = self.outc(x)
return output
def forward(self, x, t):
"Positional encoding of the timestep before the blocks"
t = t.unsqueeze(-1)
t = self.pos_encoding(t, self.time_dim)
return self.unet_forwad(x, t)
조건부 모델은 거의 동일하지만 삽입 레이어를 통해 레이블을 전달하여 클래스 레이블의 인코딩을 timestep에 추가합니다. 이는 매우 간단하고 우아한 해결책입니다.
class UNet_conditional(UNet):
def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None):
super().__init__(c_in, c_out, time_dim)
if num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_dim)

def forward(self, x, t, y=None):
t = t.unsqueeze(-1)
t = self.pos_encoding(t, self.time_dim)

if y is not None:
t += self.label_emb(y)

return self.unet_forwad(x, t)

EMA 코드

지수 이동 평균(EMA: Exponential Moving Average)은 더 좋은 결과와 더 안정적인 트레이닝으로 만들기 위해 사용되는 기술입니다. 이전 반복의 모델 가중치 복사본을 유지하고 현재 반복 가중치를 (1-베타)의 인수만큼 업데이트함으로써 작동합니다.
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
self.step = 0

def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)

def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new

def step_ema(self, ema_model, model, step_start_ema=2000):
if self.step < step_start_ema:
self.reset_parameters(ema_model, model)
self.step += 1
return
self.update_model_average(ema_model, model)
self.step += 1

def reset_parameters(self, ema_model, model):
ema_model.load_state_dict(model.state_dict())

트레이닝

저희는 코드가 기능하도록 리팩터링했습니다. 트레이닝 단계는one_epoch 함수에서 수행됩니다:
def train_step(self):
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
self.ema.step_ema(self.ema_model, self.model)
self.scheduler.step()

def one_epoch(self, train=True, use_wandb=False):
avg_loss = 0.
if train: self.model.train()
else: self.model.eval()
pbar = progress_bar(self.train_dataloader, leave=False)
for i, (images, labels) in enumerate(pbar):
with torch.autocast("cuda") and (torch.inference_mode() if not train else torch.enable_grad()):
images = images.to(self.device)
labels = labels.to(self.device)
t = self.sample_timesteps(images.shape[0]).to(self.device)
x_t, noise = self.noise_images(images, t)
if np.random.random() < 0.1:
labels = None
predicted_noise = self.model(x_t, t, labels)
loss = self.mse(noise, predicted_noise)
avg_loss += loss
if train:
self.train_step()
if use_wandb:
wandb.log({"train_mse": loss.item(),
"learning_rate": self.scheduler.get_last_lr()[0]})
pbar.comment = f"MSE={loss.item():2.3f}"
return avg_loss.mean().item()
여기 저희 W&B 도구의 첫 번째 부분에서 트레이닝 손실과 학습률 값을 보실 수 있습니다. 이렇게 하면 사용 중인 스케줄러를 따를 수 있습니다. 샘플을 실제로 기록하기 위해 모델 추론을 수행하는 사용자 정의 함수를 정의합니다.
@torch.inference_mode()
def log_images(self):
"Log images to wandb and save them to disk"
labels = torch.arange(self.num_classes).long().to(self.device)
sampled_images = self.sample(use_ema=False, n=len(labels), labels=labels)
ema_sampled_images = self.sample(use_ema=True, n=len(labels), labels=labels)
plot_images(sampled_images) #to display on jupyter if available
# log images to wandb
wandb.log({"sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in sampled_images]})
wandb.log({"ema_sampled_images": [wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in ema_sampled_images]})
또한 모델 체크포인트를 저장하는 기능도 있습니다:
def save_model(self, run_name, epoch=-1):
"Save model locally and to wandb"
torch.save(self.model.state_dict(), os.path.join("models", run_name, f"ckpt.pt"))
torch.save(self.ema_model.state_dict(), os.path.join("models", run_name, f"ema_ckpt.pt"))
torch.save(self.optimizer.state_dict(), os.path.join("models", run_name, f"optim.pt"))
at = wandb.Artifact("model", type="model", description="Model weights for DDPM conditional", metadata={"epoch": epoch})
at.add_dir(os.path.join("models", run_name))
wandb.log_artifact(at)
모든 것이fit 기능에 잘 맞습니다
def prepare(self, args):
"Prepare the model for training"
setup_logging(args.run_name)
device = args.device
self.train_dataloader, self.val_dataloader = get_data(args)
self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=0.001)
self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=args.lr,
steps_per_epoch=len(self.train_dataloader), epochs=args.epochs)
self.mse = nn.MSELoss()
self.ema = EMA(0.995)
self.scaler = torch.cuda.amp.GradScaler()

def fit(self, args):
self.prepare(args)
for epoch in range(args.epochs):
logging.info(f"Starting epoch {epoch}:")
self.one_epoch(train=True)
## validation
if args.do_validation:
self.one_epoch(train=False)
# log predicitons
if epoch % args.log_every_epoch == 0:
self.log_images(use_wandb=args.use_wandb)

# save model
self.save_model(run_name=args.run_name, use_wandb=args.use_wandb, epoch=epoch)

이미지 샘플링

이미지를 샘플링 하려면 랜덤 노이즈에서 시작하여 반복적으로 노이즈를 제거해서 최종 이미지를 얻어야 합니다. 이 절차는 "일러스트 스테이블 디퓨전(The Illustrated Stable Diffusion)"에 매우 잘 설명되어 있습니다. 저희의 경우엔 훨씬 간단하지만 샘플링은 동일합니다. 스테이블 디퓨전 아키텍처의 바이트와 비트를 이해하는 데 관심이 있는 사람들에게는 어쨌든 좋은 읽을거리입니다.
저희의 경우엔 UNet의 출력 후 이미지가 이미 완전한 해상도이기 때문에 디코더가 필요하지 않습니다.
노이즈 제거 단계를 한 번에 하나씩 수행합니다. 저희도 똑같이 이 절차를 거치고 있지만 이미지 디코더를 사용하지 않고 있습니다.
샘플링 코드는 노이즈 스케줄러에 이어 이미지에서 노이즈를 점진적으로 제거합니다. 무작위 순수 노이즈에서 시작하여 샘플 이미지로 끝납니다. 이 코드는 약간 헷갈릴 수도 있습니다. 왜냐하면DDPM paper에 있는 방정식의 이름을 따서 이름이 지어졌기 때문입니다.
저희는 이 글의 두 번째 알고리즘을 사용하고 있습니다:
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

글꼴 생성 모델 트레이닝

레이블이 지정된 다른 데이터 세트를 사용해 보겠습니다. 이 간단한 예시를 위해선 작은 이미지를 가진 데이터 세트가 좋습니다. Kaggle에서 검색한 결과 알파벳 글씨체 데이터 세트(Alphabet character fonts dataset)가 좋은 예로 보입니다. 이 데이터 세트의 모양을 한번 살펴볼까요?
데이터 세트의 각 행은 같은 글씨체이고 A에서 Z까지의 문자가 32x32픽셀 BW 이미지로 렌더링 됩니다. 이 정보를 W&B에 신속하게 기록하고 다음을 살펴볼 수 있습니다:

Run set
10

여기서 하고자 하는 것은 바로 새로운 글씨체를 생성하기 위해 디퓨전 모델을 트레이닝시키는 것입니다! CIFAR-10과 동일한 아이디어와 조건 생성을 저희가 생산하고자 하는 실제 글자에 적용할 수 있습니다. 이전과 마찬가지로 트레이닝 중에 정기적으로 출력물을 샘플링하여 트레이닝 절차를 살펴보겠습니다.


이 모델의 단점은 글씨체 (또는 스타일)를 생성할 방법이 없다는 것입니다. 각 문자를 개별적으로 가져가서 모델을 트레이닝했기 때문에 하나씩만 문자를 생성할 수 있습니다. 따라서 레이블 A-Z (또는 W&B 🤣)를 통과하더라도 매번 독립적인 무작위 문자를 받습니다.


제가 이 노트북 에서 이 wandb.Tables를 어떻게 만들었는지 확인해 보세요. 

Fastai를 사용한 디퓨전 모델 트레이닝

조건부 디퓨전 모델을 처음부터 트레이닝하는 데 코드가 얼마나 적게 필요한지 살펴보았지만, 라이브러리를 fastai로 사용하면 훨씬 더 줄일 수 있습니다! 여러분이 지금 진행 중인 "코더들을 위한 딥 러닝 파트 2"를 따르고 제레미와 함께 생성 모델에 대해 더 많이 배워 보시는 것을 권해드립니다.
그 동안에 제가 전에 보여드린 것과 같은 CIFAR 코드는 Fastai의 언어로 된 콜백일 뿐입니다. 여러 Fastai 개발자가 디퓨전 모델의 구현을 추가하고 있는 fastdiffusion저장소를 확인해보세요.
CIFAR-10 을 트레이닝하려면 이 콜백을 Image 2 Image fastai 파이프라인에 전달하기만 하면 됩니다
class ConditionalDDPMCallback(Callback):
def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage):
store_attr()

def before_fit(self):
self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device) # variance schedule, linearly increased with timestep
self.alpha = 1. - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.sigma = torch.sqrt(self.beta)

def before_batch_training(self):
x0 = self.xb[0] # original images, x_0
eps = self.tensor_type(torch.randn(x0.shape, device=x0.device)) # noise, x_T
batch_size = x0.shape[0]
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long) # select random timesteps
alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps #noisify the image
self.learn.xb = (xt, t, self.yb[0]) # input to our model is noisy image and timestep
self.learn.yb = (eps,) # ground truth is the noise


def before_batch_sampling(self):
xt = self.tensor_type(self.xb[0]) # a full batch at once!
batch_size = xt.shape[0]
label = torch.arange(10, dtype=torch.long, device=xt.device).repeat(batch_size//10 + 1).flatten()[0:batch_size]
for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
t_batch = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)
z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
alpha_t = self.alpha[t] # get noise level at current timestep
alpha_bar_t = self.alpha_bar[t]
sigma_t = self.sigma[t]
xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * self.model(xt, t_batch, label=label)) + sigma_t*z # predict x_(t-1) in accordance to Algorithm 2 in paper
self.learn.pred = (xt,)
raise CancelBatchException

def before_batch(self):
if not hasattr(self, 'gather_preds'): self.before_batch_training()
else: self.before_batch_sampling()
그리고 나서 이 콜백을 해당Learner에게 다시 전달합니다:
model = ConditionalUnet(dim=32, channels=1, num_classes=10).cuda()
ddpm_learner = Learner(dls, model,
cbs=[ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02, tensor_type=TensorImageBW)],
loss_func=nn.MSELoss()).to_fp16()
fastai는 즉시 W&B를 지원하므로WandbCallback을 전달하는 것만큼 간단합니다:
with wandb.init(project="sd_from_scratch"):
ddpm_learner.fit_one_cycle(10, 1e-4, cbs=WandbCallback(log_preds=False))
보다 포괄적인 예를 보려면 이 노트북을 확인해주세요.

결론

저는 완전 초기부터 디퓨전 모델을 개인적으로 더 잘 이해하기 위함을 주된 이유로 이 작업을 했습니다. GAN은 복잡하고 트레이닝시키기 위해 많은 공학적 기술이 필요하기 때문에 저는 GAN을 결코 좋아하지 않았습니다. 그래서 신뢰할 수 있는 UNet을 사용하여 이미지를 생성하는 이 새로운 기술이 오픈 소스 세계에 강림했을 때 반드시 시도해봐야겠다 생각했습니다!
이 새롭고 강력한 도구는 레이블이 지정된 데이터를 생성하는 새로운 방법을 제공합니다. 더 강력한 사전 트레이닝 모델을 얻기 위해 데이터 세트를 사전 트레이닝하거나 보강하는 새로운 방법이 될 수 있습니다. 제가 이 가설을 시험하는 방법에 대해 몇 가지 아이디어를 가지고 있으니 놓치지마세요.
Iterate on AI agents and models faster. Try Weights & Biases today.